从零开始搭建扩散模型
Smiling & Weeping
---- 倘若思念没有国籍,
那么此刻,我的心脏也许是亚洲最动荡的岛屿。
1. 环境准备(一些前提...杂七杂八的东西)
# 环境准备 %pip install diffusers
登录huggingface_hub社区
# Log in to the Hugging Face Hub (needed for sharing models).
# Make sure you use a token with WRITE access.
import huggingface_hub

huggingface_hub.notebook_login()
import torch
import torchvision
from datasets import load_dataset
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # echo the selected device in the notebook output
我使用的数据集是 Hugging Face 上的 lowres/anime 数据集(3×256×256)。如果 GPU 算力有限,大家可以改用 MNIST 数据集(1×28×28,注意此时需要把模型的 in_channels/out_channels 相应改为 1)。这是因为无论是手写的 BasicUNet(约 30 万参数)还是 UNet2DModel(约 170 万参数),在 256×256 的图像上训练都需要较多的 GPU 算力。
下面先使用一个非常小的经典数据集 MNIST 来进行测试。
# Quick sanity check on the tiny MNIST dataset: download it, draw one
# shuffled batch of 8 images and show them as a grid with their labels.
to_tensor = torchvision.transforms.ToTensor()
dataset = torchvision.datasets.MNIST(
    root="./data/mnist/",
    train=True,
    download=True,
    transform=to_tensor,
)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

batch_iter = iter(train_dataloader)
x, y = next(batch_iter)
print('Input shape:', x.shape)
print('Labels:', y)

grid = torchvision.utils.make_grid(x)
plt.imshow(grid[0], cmap='Greys');
加载所需数据集
# Load the lowres/anime dataset from the Hugging Face Hub and wrap it in a
# DataLoader whose batches are dicts of the form {"images": ...}.
from torchvision import transforms

dataset_name = 'lowres/anime'
dataset = load_dataset(dataset_name, split='train')

image_size = 256
batch_size = 8

# Resize -> random horizontal flip (augmentation) -> tensor in [0, 1]
# -> Normalize(0.5, 0.5), which maps [0, 1] onto [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def transform(example):
    """Turn every entry of ``example['image']`` into a preprocessed RGB tensor.

    Registered with ``dataset.set_transform`` below, so it runs lazily each
    time rows are fetched; the result is stored under the "images" key.
    """
    images = []
    for image in example['image']:
        rgb = image.convert("RGB")
        images.append(preprocess(rgb))
    return {"images": images}


dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)
3.2 扩散模型--退化过程
退化过程就是添加噪声的过程
# Simulate the degradation process: add noise.
def corrupt(x, amount):
    """Degrade ``x`` by linearly mixing it with uniform noise.

    Computes ``x * (1 - amount) + noise * amount`` where ``noise`` is drawn
    with ``torch.rand_like`` (uniform on ``[0, 1)``).

    Args:
        x: batch of samples; dimension 0 is the batch dimension. Any rank
            is accepted (e.g. ``(B, C, H, W)`` images or ``(B, H, W)``).
        amount: per-sample noise level, shape ``(B,)`` (a scalar tensor also
            works). 0 keeps a sample unchanged, 1 replaces it with noise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    noise = torch.rand_like(x)
    # Reshape to (B, 1, ..., 1) so each sample's amount broadcasts over all of
    # its non-batch dimensions. The previous hard-coded view(-1, 1, 1, 1)
    # only fit 4-D input and silently produced a wrong (B, B, ...) result
    # for other ranks.
    amount = amount.view(-1, *([1] * (x.dim() - 1)))
    # NOTE(review): the anime pipeline normalizes images to [-1, 1] while this
    # noise lies in [0, 1); kept as-is because sampling also starts from
    # torch.rand noise.
    return x * (1 - amount) + noise * amount
绘制数据集
# Show one batch of clean images next to a progressively noised copy.
x = next(iter(train_dataloader))
clean = x["images"]

# Noise level ramps linearly from 0 (first image) to 1 (last image).
amount = torch.linspace(0, 1, clean.shape[0])
noised_x = corrupt(clean, amount)

fig, axs = plt.subplots(2, 1, figsize=(7, 5))
for ax, title, imgs in zip(axs, ["Input data", 'Corrupted data'], [clean, noised_x]):
    ax.set_title(title)
    # Grid -> HWC layout for imshow, then map [-1, 1] back to [0, 1].
    ax.imshow(torchvision.utils.make_grid(imgs).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
3.3 扩散模型的训练
模型接收一张图像(默认是三通道,in_channels=3;使用 MNIST 时改为单通道)。图像先经过下行路径上的三个卷积层(图和代码中的 down_layers),再经过上行路径上的三个卷积层(up_layers),下行层与上行层之间具有残差(跳跃)连接。下采样使用最大池化层(nn.MaxPool2d),上采样使用 nn.Upsample。
class BasicUNet(nn.Module):
    """Minimal three-level UNet used as a toy denoiser.

    Down path: three 5x5 convolutions (``down_layers``), with a 2x max-pool
    after each of the first two. Up path: three 5x5 convolutions
    (``up_layers``); before the second and third, the features are upsampled
    2x and the matching down-path activation is added (skip connection).
    The output has the input's spatial size; H and W must be divisible by 4
    for the skip additions to line up.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Attribute names and layer creation order are unchanged so saved
        # state_dicts and parameter initialization stay identical.
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in [(in_channels, 32), (32, 64), (64, 64)]
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in [(64, 64), (64, 32), (32, out_channels)]
        ])
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        skips = []
        last_down = len(self.down_layers) - 1
        for idx, layer in enumerate(self.down_layers):
            # Convolution followed by the activation.
            x = self.act(layer(x))
            if idx == last_down:
                break  # bottleneck: nothing to store, no further pooling
            skips.append(x)  # kept for the skip connection
            x = self.downscale(x)

        for idx, layer in enumerate(self.up_layers):
            if idx:
                x = self.upscale(x)
                x += skips.pop()
            x = self.act(layer(x))
        return x
# A fresh toy UNet plus a random batch shaped like the anime images
# (8 RGB images of 256x256) to try it on.
net = BasicUNet()
x = torch.rand(size=(8, 3, 256, 256))
# Training DataLoader: batches of 128 samples, reshuffled every epoch.
batch_size = 128
train_dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)
# New BasicUNet on the training device, MSE reconstruction loss, Adam optimizer.
net = BasicUNet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Train BasicUNet to recover clean images from their corrupted versions.
n_epoches = 35
losses = []
for epoch in range(n_epoches):
    for batch in train_dataloader:
        clean = batch['images'].to(device)
        # One random noise level in [0, 1) per sample.
        noise_amount = torch.rand(clean.shape[0]).to(device)
        noise_x = corrupt(clean, noise_amount)
        x_pred = net(noise_x)
        loss = loss_fn(x_pred, clean)

        # Backpropagate and update the parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    # Mean loss over this epoch's batches only.
    n_batches = len(train_dataloader)
    avg_loss = sum(losses[-n_batches:]) / n_batches
    print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")

plt.plot(losses)
对比输入数据、退化数据、预测数据:(大家训练完可以看看效果)
# Compare clean inputs, their corrupted versions and the network's
# reconstructions for noise levels ramping from 0 to 1.
x = next(iter(train_dataloader))['images'][:8]  # only the first 8 images
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

panels = (
    ("Input data", x),
    ("Corrupted data", noised_x),
    ("Network Preditions", preds),
)
fig, axs = plt.subplots(3, 1, figsize=(7, 5))
for ax, (title, imgs) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(torchvision.utils.make_grid(imgs).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
3.4 扩散模型的采样过程
3.4.1 采样过程
采样过程的方案:从完全随机的噪声开始,先查看模型的预测结果,然后只朝着预测结果的方向移动一小部分;重复这一过程若干步,最后一步直接取模型的预测结果。
def sample_with_step(x, n_steps, model=None):
    """Iteratively denoise ``x`` by moving part-way toward the model's prediction.

    At step ``i`` the sample is blended with the prediction using
    ``mix_factor = 1 / (n_steps - i)``. The last step has factor 1, so the
    returned sample equals the final prediction.

    Args:
        x: starting batch (e.g. pure random noise) on the model's device.
        n_steps: number of denoising steps.
        model: callable mapping a batch to its predicted clean batch.
            Defaults to the module-level ``net`` (the original behaviour).

    Returns:
        ``(x, step_history, pred_output_history)``: the final sample; the
        ``n_steps + 1`` intermediate samples as CPU copies, starting with the
        input; and the ``n_steps`` model predictions as CPU copies.
    """
    if model is None:
        model = net  # module-level model, as in the original version
    step_history = [x.detach().cpu()]
    pred_output_history = []
    for i in range(n_steps):
        with torch.no_grad():
            pred = model(x)
        # Keep the model output for later visualization.
        pred_output_history.append(pred.detach().cpu())
        # Fraction of the remaining way to move toward the prediction.
        mix_factor = 1 / (n_steps - i)
        x = x * (1 - mix_factor) + pred * mix_factor
        step_history.append(x.detach().cpu())
    return x, step_history, pred_output_history
# Start from completely random values and denoise them in 5 steps.
n_steps = 5
start = torch.rand(8, 3, 256, 256).to(device)
x, step_history, pred_output_history = sample_with_step(start, n_steps=n_steps)
绘图代码请自行编写(例如用 step_history 和 pred_output_history 画出每一步的输入与模型预测)。
3.4.2 UNet2DModel
本文做法与 DDPM 的区别:
- UNet2DModel比BasicUNet更先进。
- 退化过程的处理方式不同。
- 训练目标不同:DDPM 让模型预测噪声,而不是像本文这样预测去噪后的图像。
- 模型通过时间步 t 来感知噪声量(timestep conditioning),t 作为一个额外参数传入前向过程(本文的代码中始终传入 t=0)。
Diffusers库中的UNet2DModel模型比BasicUNet模型有如下改进:
- GroupNorm层对每个模块的输入进行了组标准化(group normalization)。
- Dropout层能使训练更平滑。
- 每个块有多个ResNet层(如果layers_per_block未设置为1)。
- 引入了注意力机制(通常仅用于输入分辨率较低的blocks)。
- 可以对时间步进行调节。
- 具有可学习参数的上采样模块和下采样模块。
# diffusers UNet2DModel: a more capable replacement for BasicUNet.
net = UNet2DModel(
    # Target image resolution recorded in the model config. The anime images
    # are resized to 256x256 in preprocessing; the previous value 28 is the
    # MNIST resolution and did not match this 3-channel setup.
    sample_size=256,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,  # number of ResNet layers in each UNet block
    block_out_channels=(32, 64, 64),
    down_block_types=(
        "DownBlock2D",      # plain ResNet downsampling block
        "AttnDownBlock2D",  # ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",    # ResNet upsampling block with spatial self-attention
        "UpBlock2D",        # plain ResNet upsampling block
    ),
)
# Train UNet2DModel with the same "recover the clean image" objective.
# Training DataLoader.
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_epochs = 3
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []
for epoch in range(n_epochs):
    for batch in train_dataloader:
        # The anime dataset (via set_transform) yields dicts {"images": ...},
        # while torchvision datasets such as MNIST yield (images, labels).
        # The previous `for x, y in train_dataloader` unpacked the dict's
        # keys and raised ValueError on the single-key anime batches.
        x = batch['images'] if isinstance(batch, dict) else batch[0]
        x = x.to(device)
        # One random noise level in [0, 1) per sample.
        noise_amount = torch.rand(x.shape[0]).to(device)
        # Degradation process.
        noisy_x = corrupt(x, noise_amount)
        # Timestep fixed to 0: the model is not told the noise level.
        pred = net(noisy_x, 0).sample
        loss = loss_fn(pred, x)
        # Backpropagate and update the parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    # Mean loss over this epoch's batches.
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')
def sample(x, n_steps, model=None):
    """Iteratively denoise ``x`` with a UNet2DModel-style network.

    At step ``i`` the sample moves a fraction ``1 / (n_steps - i)`` of the
    way toward the model's prediction. The final step therefore has factor 1
    and returns the last prediction itself.

    Args:
        x: starting batch (noise or noisy images) on the model's device.
        n_steps: number of denoising steps; 0 returns ``x`` unchanged.
        model: callable ``model(x, timestep)`` returning an object with a
            ``.sample`` tensor. Defaults to the module-level ``net``.

    Returns:
        The denoised batch, same shape as ``x``.
    """
    if model is None:
        model = net  # module-level model, as in the original version
    for i in range(n_steps):
        with torch.no_grad():
            # Timestep 0, matching how the model was trained.
            pred = model(x, 0).sample
        # Was `1/(n_steps - i+1)`: the last step only moved halfway, so the
        # output never reached the prediction. This also removes the unused
        # noise_amount local.
        mix_factor = 1 / (n_steps - i)
        x = x * (1 - mix_factor) + pred * mix_factor
    return x
# Load trained weights, then plot: clean images, the same images noised with
# the DDPM scheduler at increasing timesteps, and the model's reconstructions.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('/path/to/model', map_location=device)
net.load_state_dict(checkpoint['model'])
net = net.to(device)
from diffusers import DDPMScheduler, UNet2DModel
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
fig, axs = plt.subplots(3, 1, figsize=(16, 10))
xb = next(iter(train_dataloader))
# Keep only 8 images: add_noise below receives 8 timesteps, which cannot
# broadcast against a larger batch (batch_size may be e.g. 128 at this point).
xb = xb['images'].to(device)[:8]

# Show clean inputs.
axs[0].imshow(torchvision.utils.make_grid(xb[:8]).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
axs[0].set_title('Clean X')

# Add noise: one timestep per image, evenly spaced from 0 (least noise)
# to 999 (most noise).
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)

# Show the noisy version, clipped to [-1, 1] for display.
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8]).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
axs[1].set_title('Noisy X clip to (-1, 1)')

# NOTE(review): the model was trained on uniform-noise corruption (corrupt),
# while these inputs carry Gaussian DDPM noise; expect imperfect results.
pred_x = sample(noisy_xb[:8], 10)
axs[2].imshow(torchvision.utils.make_grid(pred_x).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)
axs[2].set_title("Preidtion X")
文章到此结束,我们下次再见
可許我辞忧(๑′ᴗ‵๑)I Lᵒᵛᵉᵧₒᵤ❤
本文作者:Smiling-Weeping
本文链接:https://www.cnblogs.com/smiling-weeping-zhr/p/18022225
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步