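GAN Series 2: Generating Handwritten Digit Images with a Simple GAN
Objective:
Use PyTorch to build a GAN that generates handwritten digit images (MNIST).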
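Series contents:
1. Learning the basic GAN architecture;
2. Training the generator and the discriminator;
3. The loss functions of the generator and the discriminator in a GAN;
4. Various architectures that apply GANs.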
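Training the discriminator:
1) Get the real data and the real labels (real labels are 1). The number of real labels must equal the batch size;
2) Forward pass: feed the real data to the discriminator to get its outputs on the real data;
3) Compute the discriminator loss from these outputs and the real labels, and backpropagate it;
4) Generate fake data with the generator, forward it through the discriminator, and compute the loss on the generated data against the fake labels (0); backpropagate this loss as well. The total discriminator loss is the sum of the real-data loss and the fake-data loss;
5) Update the discriminator's parameters.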
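To make steps 3 and 4 concrete, here is the per-sample arithmetic of the discriminator loss in plain Python, a sketch assuming a single real and a single generated sample: with nn.BCELoss, a target of 1 contributes -log D(x), a target of 0 contributes -log(1 - D(G(z))), and the two terms are summed. discriminator_loss is a hypothetical helper for illustration only.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Sum of the two BCE terms from steps 3 and 4 for one sample (hypothetical helper).

    d_real: discriminator output D(x) on a real image (target label 1)
    d_fake: discriminator output D(G(z)) on a generated image (target label 0)
    """
    loss_real = -math.log(d_real)        # BCE with label 1: -log(p)
    loss_fake = -math.log(1.0 - d_fake)  # BCE with label 0: -log(1 - p)
    return loss_real + loss_fake

# A discriminator that cannot tell real from fake outputs 0.5 for both
print(round(discriminator_loss(0.5, 0.5), 4))  # 1.3863 (= 2 * ln 2)
# A confident, correct discriminator has a much lower loss
print(round(discriminator_loss(0.9, 0.1), 4))  # 0.2107
```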
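Training the generator:
1) Run a forward pass through the generator to get generated data, and label it as real (1);
2) Forward the generated data through the discriminator;
3) Compute the loss and backpropagate it;
4) Update (optimize) the generator's parameters.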
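Labeling the generated data as 1 in step 1 means the generator minimizes -log D(G(z)) instead of directly minimizing log(1 - D(G(z))). The plain-Python sketch below, assuming the discriminator ends in a sigmoid over a logit a, compares the gradients of the two objectives with respect to a when the discriminator easily rejects a fake (D(G(z)) = 0.01, as is typical early in training). The helper names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_nonsaturating(a):
    # d/da of -log(sigmoid(a)): the generator loss when fake data is labeled 1
    return sigmoid(a) - 1.0

def grad_saturating(a):
    # d/da of log(1 - sigmoid(a)): the term the generator would otherwise minimize
    return -sigmoid(a)

a = math.log(0.01 / 0.99)  # logit at which D(G(z)) = 0.01
print(round(sigmoid(a), 4))             # 0.01
print(round(grad_nonsaturating(a), 4))  # -0.99
print(round(grad_saturating(a), 4))     # -0.01
```

The label-1 objective keeps a gradient of about -0.99, while the other one shrinks to about -0.01, so training the generator against real labels gives it a much stronger learning signal when it is still poor.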
File structure:
├───input
├───outputs
└───src
        vanilla_gan.py
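The script is meant to be run from inside src/, so the relative paths ../input/data and ../outputs used later point at the sibling folders, and the training loop writes its images into ../outputs, which should exist beforehand. A minimal sketch for creating the layout with the standard library, assuming you start from an empty project root; create_project_dirs is a hypothetical helper, not part of the original code:

```python
import tempfile
from pathlib import Path

def create_project_dirs(root):
    """Create the input/, outputs/ and src/ folders under root (hypothetical helper)."""
    root = Path(root)
    for name in ("input", "outputs", "src"):
        # exist_ok=True makes the helper safe to run more than once
        (root / name).mkdir(parents=True, exist_ok=True)
    return sorted(p.name for p in root.iterdir() if p.is_dir())

# Usage: build the layout in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    print(create_project_dirs(tmp))  # ['input', 'outputs', 'src']
```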
Code implementation:
All of our code goes in vanilla_gan.py.
1) Import the packages
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
matplotlib.style.use('ggplot')
make_grid() arranges a batch of images into a single grid image, and save_image() writes a tensor to an image file; together they let us store the generated images.
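As a rough illustration of what make_grid() produces, the plain-Python sketch below reproduces its layout arithmetic, assuming torchvision's defaults nrow=8 and padding=2 (grid_shape is a hypothetical helper, not a torchvision function). For the 64 generated 1×28×28 images saved each epoch, the result is an 8×8 mosaic, and make_grid also repeats a single channel three times.

```python
import math

def grid_shape(n_images, channels, height, width, nrow=8, padding=2):
    """Shape of the tensor make_grid() returns for a batch (hypothetical helper)."""
    out_channels = 3 if channels == 1 else channels  # 1-channel images become 3-channel
    if n_images == 1:
        return (out_channels, height, width)         # a single image is returned unpadded
    xmaps = min(nrow, n_images)                      # images per row
    ymaps = math.ceil(n_images / xmaps)              # number of rows
    grid_h = ymaps * (height + padding) + padding
    grid_w = xmaps * (width + padding) + padding
    return (out_channels, grid_h, grid_w)

print(grid_shape(64, 1, 28, 28))  # (3, 242, 242)
```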
2) Define the learning parameters:
# learning parameters
batch_size = 512
epochs = 200
sample_size = 64  # fixed sample size: number of images generated for visualization each epoch
nz = 128  # latent vector size
k = 1  # number of discriminator steps per generator step
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
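MNIST's training set has 60,000 images, and 60,000 is not a multiple of 512, so the DataLoader (with its default drop_last=False) yields one smaller final batch. That is why the training loop reads b_size = len(image) rather than reusing batch_size. A plain-Python sketch of the batch sizes per epoch; batch_sizes is a hypothetical helper:

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of the batches a DataLoader yields in one epoch (hypothetical helper)."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # the leftover samples form a smaller last batch
    return sizes

sizes = batch_sizes(60000, 512)  # MNIST training set, batch_size = 512
print(len(sizes), sizes[-1])     # 118 96
```

With k = 1 this means 118 discriminator and 118 generator updates per epoch, or 23,600 of each over 200 epochs.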
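3) Prepare the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale pixels from [0, 1] to [-1, 1]
])
to_pil_image = transforms.ToPILImage()
The last line creates to_pil_image, which converts a tensor into a PIL image. This is useful when we want to store the images generated by the GAN with tools that expect PIL images rather than tensors: the tensors have to be converted to PIL images first.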
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
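Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel, so the [0, 1] values produced by ToTensor() land in [-1, 1], the same range as the generator's final nn.Tanh(). The plain-Python sketch below shows the mapping and its inverse, which is what you would apply to generated images before treating them as ordinary [0, 1] pixels; normalize and denormalize are hypothetical helpers.

```python
def normalize(x, mean=0.5, std=0.5):
    # What transforms.Normalize((0.5,), (0.5,)) does to each pixel after ToTensor()
    return (x - mean) / std

def denormalize(y):
    # Inverse mapping from the Tanh range [-1, 1] back to [0, 1]
    return (y + 1.0) / 2.0

for pixel in (0.0, 0.5, 1.0):             # ToTensor() scales pixels to [0, 1]
    print(pixel, "->", normalize(pixel))  # 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
```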
Define the GAN's generator and discriminator
Generator:
The generator uses simple linear layers.
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),  # input: nz = 128 features, output: 256
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),  # 784 = 28 * 28 pixels
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized real images
        )

    def forward(self, x):
        # reshape each flat 784-vector into a (1, 28, 28) image
        return self.main(x).view(-1, 1, 28, 28)
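A quick way to sanity-check the layer sizes: each nn.Linear(in, out) holds in × out weights plus out biases, and the last layer's 784 outputs are exactly the 1 × 28 × 28 pixels that forward() reshapes into. This plain-Python tally is a sketch with a hypothetical linear_params helper; for the model above it should agree with sum(p.numel() for p in generator.parameters()), since LeakyReLU and Tanh have no parameters.

```python
def linear_params(in_features, out_features):
    # nn.Linear stores a weight of shape (out, in) and a bias of shape (out,)
    return in_features * out_features + out_features

nz = 128
layer_sizes = [nz, 256, 512, 1024, 784]  # widths of the generator's Linear layers
total = sum(linear_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:]))
print(total)                            # 1493520
print(layer_sizes[-1] == 1 * 28 * 28)   # True: the output can be viewed as (1, 28, 28)
```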
Discriminator:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784  # 28 * 28 pixels per flattened image
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, x):
        x = x.view(-1, 784)  # flatten (batch, 1, 28, 28) to (batch, 784)
        return self.main(x)
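The discriminator's nonlinearities act element-wise and are easy to state in plain Python. In this sketch (helper names are hypothetical), nn.LeakyReLU(0.2) passes positive inputs through and scales negative ones by 0.2, and the final nn.Sigmoid() turns the last score into a probability in (0, 1), which is the input nn.BCELoss expects. Note also that nn.Dropout(0.3) is only active in training mode, i.e. after discriminator.train().

```python
import math

def leaky_relu(x, negative_slope=0.2):
    # nn.LeakyReLU(0.2): keeps a small slope for negative inputs
    return x if x >= 0 else negative_slope * x

def sigmoid(x):
    # nn.Sigmoid(): squashes the final score into (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

print(leaky_relu(3.0), leaky_relu(-1.0))  # 3.0 -0.2
print(sigmoid(0.0))                       # 0.5
```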
Initialize the networks and define the optimizers
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)
print('##### GENERATOR #####')
print(generator)
print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')
Optimizers:
# optimizers
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)
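Both optimizers use Adam with lr=0.0002 and PyTorch's default betas=(0.9, 0.999) and eps=1e-8. The sketch below computes Adam's bias-corrected first update for a single scalar parameter in plain Python (adam_first_step is a hypothetical helper): on the first step the update has a magnitude close to lr regardless of how large the gradient is.

```python
import math

def adam_first_step(grad, lr=0.0002, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam's very first update for a scalar parameter (hypothetical helper)."""
    m = (1 - beta1) * grad        # first-moment estimate after one step
    v = (1 - beta2) * grad ** 2   # second-moment estimate after one step
    m_hat = m / (1 - beta1)       # bias correction at step 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)

# Very different gradients, nearly identical first steps (about -0.0002 each)
print(adam_first_step(5.0), adam_first_step(0.001))
```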
Loss function:
# loss function
criterion = nn.BCELoss()  # binary cross-entropy; expects probabilities from the Sigmoid
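nn.BCELoss() averages -[y·log p + (1 - y)·log(1 - p)] over the batch. PyTorch's implementation also clamps each log term at -100, so a prediction of exactly 0 or 1 gives a large but finite loss instead of infinity. A plain-Python sketch of that computation, with a hypothetical bce_loss helper:

```python
import math

def bce_loss(preds, targets):
    """Mean binary cross-entropy with log terms clamped at -100 (hypothetical helper)."""
    total = 0.0
    for p, y in zip(preds, targets):
        log_p = max(math.log(p), -100.0) if p > 0 else -100.0
        log_1mp = max(math.log(1 - p), -100.0) if p < 1 else -100.0
        total += -(y * log_p + (1 - y) * log_1mp)
    return total / len(preds)  # reduction='mean'

print(round(bce_loss([0.9, 0.2], [1.0, 0.0]), 4))  # 0.1643
print(bce_loss([0.0], [1.0]))                      # 100.0 instead of infinity
```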
Lists for storing the losses after each epoch:
losses_g = []  # to store generator loss after each epoch
losses_d = []  # to store discriminator loss after each epoch
images = []  # to store images generated by the generator
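Some other helper functions:
During GAN training we need labels for the real and the generated images in order to compute the losses.
We define two functions that create labels of 1s and 0s:
# to create real labels (1s); shape (size, 1) matches the discriminator output
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

# to create fake labels (0s)
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)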
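The generator also needs a noise vector as input: each vector has length nz (128) and is used to generate one image.
# function to create the noise vector
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)
This function takes two arguments, sample_size and nz.
It returns a sample_size × nz tensor of random values drawn from a standard normal distribution, which is later fed into the generator to produce fake images.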
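Each row of the noise tensor is one latent vector, and the training loop reuses one fixed noise batch every epoch so the saved images can be compared across epochs. The plain-Python sketch below mimics create_noise() with random.gauss and a seed to show that a seeded draw is reproducible; create_noise_py is a hypothetical stand-in, and in PyTorch calling torch.manual_seed(0) before torch.randn plays the same role.

```python
import random

def create_noise_py(sample_size, nz, seed=None):
    """sample_size x nz matrix of N(0, 1) draws (hypothetical stand-in for create_noise)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(nz)] for _ in range(sample_size)]

fixed = create_noise_py(64, 128, seed=0)
print(len(fixed), len(fixed[0]))                  # 64 128
print(fixed == create_noise_py(64, 128, seed=0))  # True: same seed, same vectors
```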
Finally, a function to save the generated images:
# to save the images generated by the generator
def save_generator_image(image, path):
    save_image(image, path)
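Function to train the discriminator:
# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)
    optimizer.zero_grad()
    # loss on real data (target 1)
    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)
    # loss on generated data (target 0); the caller passes detached fake data
    output_fake = discriminator(data_fake)
    loss_fake = criterion(output_fake, fake_label)
    # gradients of both losses accumulate before a single update
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()
    return loss_real + loss_fake
Function to train the generator, called in the training loop below; it follows the generator training steps listed earlier:
# function to train the generator network
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)  # generated data is labeled as real (1)
    optimizer.zero_grad()
    output = discriminator(data_fake)
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()  # only the generator's parameters are updated here
    return loss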
Training the GAN
# fixed noise vector, reused every epoch so the generated images can be compared across epochs
noise = create_noise(sample_size, nz)
generator.train()
discriminator.train()
Start training:
for epoch in range(epochs):
    loss_g = 0.0
    loss_d = 0.0
    for bi, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        image, _ = data
        image = image.to(device)
        b_size = len(image)  # the last batch can be smaller than batch_size
        # run the discriminator for k number of steps
        for step in range(k):
            data_fake = generator(create_noise(b_size, nz)).detach()
            data_real = image
            # train the discriminator network; .item() stores a plain float
            loss_d += train_discriminator(optim_d, data_real, data_fake).item()
        data_fake = generator(create_noise(b_size, nz))
        # train the generator network
        loss_g += train_generator(optim_g, data_fake).item()
    # create the final fake image for the epoch from the fixed noise
    generated_img = generator(noise).cpu().detach()
    # make the images as grid; normalize=True rescales the Tanh output to [0, 1]
    generated_img = make_grid(generated_img, normalize=True)
    # save the generated images to disk
    save_generator_image(generated_img, f"../outputs/gen_img{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / len(train_loader)  # average generator loss per batch
    epoch_loss_d = loss_d / len(train_loader)  # average discriminator loss per batch
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)

    print(f"Epoch {epoch + 1} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")
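Each epoch's reported loss is the sum of the per-batch losses divided by the number of batches. Note that enumerate() leaves bi at one less than the number of batches, so dividing by bi would inflate the average slightly. A plain-Python sketch with a hypothetical epoch_average helper:

```python
def epoch_average(batch_losses):
    """Average per-batch loss over one epoch (hypothetical helper)."""
    total = 0.0
    for bi, loss in enumerate(batch_losses):
        total += loss
    # after the loop bi == len(batch_losses) - 1, so divide by the batch count instead
    return total / len(batch_losses)

losses = [1.0, 1.0, 1.0, 1.0]
print(epoch_average(losses))  # 1.0
print(sum(losses) / 3)        # dividing by the last index bi = 3 gives 1.333...
```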