喜欢谈论痛苦的往往是|

smiling&weeping

园龄:1年8个月粉丝:2关注:0

从零开始搭建扩散模型

Smiling & Weeping

 

                    ---- 倘若思念没有国籍,

                 那么此刻,我的心脏也许是亚洲最动荡的岛屿。

1. 环境准备(一些前提...杂七杂八的东西)

# 环境准备
%pip install diffusers

登录huggingface_hub社区

# Code to log in to the Hugging Face Hub, needed for sharing models
# Make sure you use a token with WRITE access
from huggingface_hub import notebook_login

notebook_login()

 

复制代码
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
复制代码

 

我使用的数据集是huggingface的lowres/anime(3x256x256)数据集,大家可以使用MNIST(1x28x28)数据集,因为从BasicUNet(手写parameters30万左右)到使用UNet2DModel(parameters170万左右)训练时,会需要较多的GPU算力资源

使用一个使用一个非常小的经典数据集mnist来进行测试。

dataset = torchvision.datasets.MNIST(root="./data/mnist/", train=True, download=True, 
                                     transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

加载所需数据集

复制代码
dataset_name = 'lowres/anime'
dataset = load_dataset(dataset_name, split='train')

from torchvision import transforms

image_size = 256
batch_size = 8

preprocess = transforms.Compose(
[
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def transform(example):
    images = [preprocess(image.convert("RGB")) for image in example['image']]
    return {"images": images}

dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True)
复制代码

 

3.2 扩散模型--退化过程

退化过程就是添加噪声的过程

# 模拟退化过程----添加噪声
def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount

绘制数据集

复制代码
x = next(iter(train_dataloader))

# 绘制输入数据
fig, axs = plt.subplots(2, 1, figsize=(7, 5))
axs[0].set_title("Input data")
axs[0].imshow(torchvision.utils.make_grid(x["images"]).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

# 加入噪声
amount = torch.linspace(0, 1, x["images"].shape[0])
noised_x = corrupt(x["images"], amount)

# 绘制加噪版本的图像
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
复制代码

3.3 扩散模型的训练

模型接收一个单通道图像,并通过下行路径上的三个卷积层(图和代码中的down_layers)和上行路径上的3个卷积层,在下行和上行层之间具有残差连接。使用最大池化层进行下采样和nn.Upsample用于上采样。

复制代码
class BasicUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)
    
    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            # 通过运算层与激活层
            x = self.act(l(x))
            if i < 2:
                # 排列供残差链接使用的数据
                h.append(x)
                # 链接下采样
                x = self.downscale(x)
                
        for i, l in enumerate(self.up_layers):
            if i > 0:
                # 链接上采样
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
    
        return x
复制代码
net = BasicUNet()
x = torch.rand(8, 3, 256, 256)
# 训练数据加载器
batch_size = 128 train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
net = BasicUNet()
net.to(device)

loss_fn = nn.MSELoss()

opt = torch.optim.Adam(net.parameters(), lr=1e-3)
复制代码
n_epoches = 35

losses = []

for epoch in range(n_epoches):
    for x1 in train_dataloader:
        # 得到数据并准备退化, 干净数据
        x = x1['images']
        x = x.to(device)
        # 随机噪声
        noise_amount = torch.rand(x.shape[0]).to(device)
        # 退化过程, 噪声数据
        noise_x = corrupt(x, noise_amount)
        
        # 得到预测结果
        x_pred = net(noise_x)
        
        # 计算损失函数
        loss = loss_fn(x_pred, x)
        
        # 反向传播并更新参数
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        losses.append(loss.item())
        
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
复制代码

 

plt.plot(losses)

 对比输入数据、退化数据、预测数据:(大家训练完可以看看效果)

复制代码
x = next(iter(train_dataloader))
x = x['images']
# 只取前8个数
x = x[:8]

amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

# 绘制对比图
fig, axs = plt.subplots(3, 1, figsize=(7, 5))
axs[0].set_title("Input data")
axs[0].imshow(torchvision.utils.make_grid(x).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
axs[1].set_title("Corrupted data")
axs[1].imshow(torchvision.utils.make_grid(noised_x).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
axs[2].set_title("Network Preditions")
axs[2].imshow(torchvision.utils.make_grid(preds).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
复制代码

 

3.4 扩散模型的采样过程

3.4.1 采样过程

采样过程的方案:从完全随机的噪声开始,先检查一下模型的预测结果,然后只朝着预测方向移动一小部分

复制代码
def sample_with_step(x, n_steps):
    step_history = [x.detach().cpu()]
    pred_output_history = []

    for i in range(n_steps):
        with torch.no_grad(): 
            pred = net(x) 
        # 将模型输出保存下来    
        pred_output_history.append(pred.detach().cpu())
        # 朝着预测方向移动的因子
        mix_factor = 1/(n_steps - i) 
        x = x*(1-mix_factor) + pred*mix_factor 
        step_history.append(x.detach().cpu()) 
    
    return x, step_history, pred_output_history
复制代码
n_steps = 5
# 完全随机的值开始
x = torch.rand(8, 3, 256, 256).to(device)
x, step_history, pred_output_history = sample_with_step(x, n_steps)

绘图自己配置

3.4.2 UNet2DModel

UNet2DModel与DDPM对比:

  • UNet2DModel比BasicUNet更先进。
  • 退化过程的处理方式不同。
  • 训练目标不同,包括预测噪声而不是去噪图像。
  • UNet2DModel模型通过调节时间步来调节噪声量, 其中t作为一个额外参数传入前向过程中。

Diffusers库中的UNet2DModel模型比BasicUNet模型有如下改进:

  • GroupNorm层对每个模块的输入进行了组标准化(group normalization)。
  • Dropout层能使训练更平滑。
  • 每个块有多个ResNet层(如果layers_per_block未设置为1)。
  • 引入了注意力机制(通常仅用于输入分辨率较低的blocks)。
  • 可以对时间步进行调节。
  • 具有可学习参数的上采样模块和下采样模块。
复制代码
net = UNet2DModel(
    sample_size=28,           # 目标图像的分辨率
    in_channels=3,            
    out_channels=3,           
    layers_per_block=2,       # 每一个UNet块中的ResNet层数
    block_out_channels=(32, 64, 64), 
    down_block_types=( 
        "DownBlock2D",        # 下采样模块
        "AttnDownBlock2D",    # 带有空域维度的self-att的ResNet下采样模块
        "AttnDownBlock2D",
    ), 
    up_block_types=(
        "AttnUpBlock2D", 
        "AttnUpBlock2D",      # 带有空域维度的self-att的ResNet上采样模块
        "UpBlock2D",          # 上采样模块
      ),
)
复制代码
复制代码
# 训练数据加载器
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3

net.to(device)

loss_fn = nn.MSELoss()

opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

losses = []

# 开始训练
for epoch in range(n_epochs):

    for x, y in train_dataloader:
        # 得到数据并准备退化
        x = x.to(device)
        # 随机噪声
        noise_amount = torch.rand(x.shape[0]).to(device) 
        # 退化过程
        noisy_x = corrupt(x, noise_amount) 
        
        # 得到预测结果
        pred = net(noisy_x, 0).sample
        
        # 计算损失值
        loss = loss_fn(pred, x) 
        
        # 反向传播并更新参数
        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss.item())
    
    # 输出损失的均值
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')
复制代码

 

复制代码
def sample(x, n_steps):
    
    for i in range(n_steps):
        noise_amount = torch.ones((x.shape[0], )).to(device)*(1-(i/n_steps))
        with torch.no_grad():
            pred = net(x, 0).sample
        mix_factor = 1/(n_steps - i+1)
        x = x*(1-mix_factor) + pred*mix_factor
        
    return x
复制代码

 

复制代码
net.load_state_dict(torch.load('/path/to/model')['model'])
net = net.to(device)

from diffusers import DDPMScheduler, UNet2DModel

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

fig, axs = plt.subplots(3, 1, figsize=(16, 10))
xb = next(iter(train_dataloader))
xb = xb['images'].to(device)

# Show clean inputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8]).permute(1, 2, 0).cpu().clip(-1, 1)*0.5+0.5)
axs[0].set_title('Clean X')

# Add noise with inputs
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)

# Show noisy version (with and without clipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8]).permute(1, 2, 0).cpu().clip(-1, 1)*0.5+0.5)
axs[1].set_title('Noisy X clip to (-1, 1)')

pred_x = sample(noisy_xb[:8], 10)
axs[2].imshow(torchvision.utils.make_grid(pred_x).permute(1, 2, 0).cpu().clip(-1, 1)*0.5+0.5)
axs[2].set_title("Preidtion X")
复制代码

 

文章到此结束,我们下次再见

 可許我辞忧(๑′ᴗ‵๑)I Lᵒᵛᵉᵧₒᵤ❤

本文作者:Smiling-Weeping

本文链接:https://www.cnblogs.com/smiling-weeping-zhr/p/18022225

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   smiling&weeping  阅读(141)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起