GAN系列2:利用简单的GAN生成手写体图像

目的:

基于pytorch利用GAN生成手写体图像;

系列内容:

一、学习GAN基本架构;

二、生成器和判别器的训练;

三、GAN中生成器和判别器的损失函数;

四、各种应用GAN的架构;

 

训练判别器:

1)得到真实数据和真实标签(真实标签标记为1);真实标签的长度应该等于batch size的长度;

2)前向传播,将真实的数据传给班别器,得到来自真实数据的真实输出;

3)计算判别器损失从真实的输出和标签中,并且反向传播它;

4)使用生成的数据,通过生成器进行前向传播,计算生成数据的输出和生成数据的损失;反向传播生成数据的损失;通过计算真实数据损失和生成数据损失,计算整体损失;

5)更新判别器的参数;

 

训练生成器:

1)通过前向传播得到生成器的生成数据;标记为1;

2)通过判别器做前向传播;

3)计算损失并且反向传播;

4)更新并优化生成器参数;

 

文件结构:

├───input
├───outputs
└───src
        vanilla_gan.py

 

代码实现:

我们将在vanilla_gan.py中实现我们所有的代码;

1) 导入包

 1 import torch
 2 import torch.nn as nn
 3 import torchvision.transforms as transforms
 4 import torch.optim as optim
 5 import torchvision.datasets as datasets
 6 import imageio
 7 import numpy as np
 8 import matplotlib
 9 from torchvision.utils import make_grid, save_image
10 from torch.utils.data import DataLoader
11 from matplotlib import pyplot as plt
12 from tqdm import tqdm
13 matplotlib.style.use('ggplot')

make_grid()和save_image有助于图像的存储;

 

2)学习参数的定义:

1 # learning parameters
2 batch_size = 512  
3 epochs = 200 
4 sample_size = 64 # fixed sample size
5 nz = 128 # latent vector size
6 k = 1 # number of steps to apply to the discriminator
7 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

3)数据集的准备

1 transform = transforms.Compose([
2                                 transforms.ToTensor(),
3                                 transforms.Normalize((0.5,),(0.5,)),
4 ])
5 to_pil_image = transforms.ToPILImage()

Line5将数据转换为PIL图像格式;这是必要的;当我们想存储GAN生成的图像;在存储之前,必须转换为PIL图像格式;

train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

定义GAN的判别器和分类器

生成器:

使用简单的线性层;

 

 1 class Generator(nn.Module):
 2     def __init__(self, nz):
 3         super(Generator, self).__init__()
 4         self.nz = nz
 5         self.main = nn.Sequential(
 6             nn.Linear(self.nz, 256),  #输入的特征为128,输出256
 7             nn.LeakyReLU(0.2),
8 nn.Linear(256, 512), 9 nn.LeakyReLU(0.2),
10 nn.Linear(512, 1024), 11 nn.LeakyReLU(0.2),
12 nn.Linear(1024, 784), 13 nn.Tanh(), 14 )
15 def forward(self, x): 16 return self.main(x).view(-1, 1, 28, 28)

 

判别器:

 1 class Discriminator(nn.Module):
 2     def __init__(self):
 3         super(Discriminator, self).__init__()
 4         self.n_input = 784
 5         self.main = nn.Sequential(
 6             nn.Linear(self.n_input, 1024),
 7             nn.LeakyReLU(0.2),
 8             nn.Dropout(0.3),
 9             nn.Linear(1024, 512),
10             nn.LeakyReLU(0.2),
11             nn.Dropout(0.3),
12             nn.Linear(512, 256),
13             nn.LeakyReLU(0.2),
14             nn.Dropout(0.3),
15             nn.Linear(256, 1),
16             nn.Sigmoid(),
17         )
18     def forward(self, x):
19         x = x.view(-1, 784)
20         return self.main(x)

初始化NN、定义优化器

1 generator = Generator(nz).to(device)
2 discriminator = Discriminator().to(device)
3 print('##### GENERATOR #####')
4 print(generator)
5 print('######################')
6 print('\n##### DISCRIMINATOR #####')
7 print(discriminator)
8 print('######################')

优化器:

1 # optimizers
2 optim_g = optim.Adam(generator.parameters(), lr=0.0002)
3 optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

损失函数:

1 # loss function
2 criterion = nn.BCELoss()

每次迭代后的损失存储:

1 losses_g = [] # to store generator loss after each epoch
2 losses_d = [] # to store discriminator loss after each epoch
3 images = [] # to store images generatd by the generator

定义一些其他函数:

在GAN训练过程中,我们需要真实图像和生成图像的标记,用于计算损失;

定义两个函数,用于生成1和0

1 # to create real labels (1s)
2 def label_real(size):
3     data = torch.ones(size, 1)
4     return data.to(device)
5 # to create fake labels (0s)
6 def label_fake(size):
7     data = torch.zeros(size, 1)
8     return data.to(device)

在生成器中,我们也需要一个噪音向量,这个向量应该等于nz(128)用于生成图像;

1 # function to create the noise vector
2 def create_noise(sample_size, nz):
3     return torch.randn(sample_size, nz).to(device)

这个函数接受两个参数:sample_size以及nz。

它将返回一个随机向量,后续用于输入生成器中生成假的图像;

 

最后保存生成的图像

1 # to save the images generated by the generator
2 def save_generator_image(image, path):
3     save_image(image, path)

 

训练判别器的函数:

 1 # function to train the discriminator network
 2 def train_discriminator(optimizer, data_real, data_fake):
 3     b_size = data_real.size(0)
 4     real_label = label_real(b_size)
 5     fake_label = label_fake(b_size)
 6     optimizer.zero_grad()
 7     output_real = discriminator(data_real)
 8     loss_real = criterion(output_real, real_label)
 9     output_fake = discriminator(data_fake)
10     loss_fake = criterion(output_fake, fake_label)
11     loss_real.backward()
12     loss_fake.backward()
13     optimizer.step()
14     return loss_real + loss_fake

训练GAN

1 # create the noise vector
2 noise = create_noise(sample_size, nz)

 

1 generator.train()
2 discriminator.train()

开始训练:

 1 for epoch in range(epochs):
 2     loss_g = 0.0
 3     loss_d = 0.0
 4     for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
 5         image, _ = data
 6         image = image.to(device)
 7         b_size = len(image)
 8         # run the discriminator for k number of steps
 9         for step in range(k):
10             data_fake = generator(create_noise(b_size, nz)).detach()
11             data_real = image
12             # train the discriminator network
13             loss_d += train_discriminator(optim_d, data_real, data_fake)
14         data_fake = generator(create_noise(b_size, nz))
15         # train the generator network
16         loss_g += train_generator(optim_g, data_fake)
17     # create the final fake image for the epoch
18     generated_img = generator(noise).cpu().detach()
19     # make the images as grid
20     generated_img = make_grid(generated_img)
21     # save the generated torch tensor models to disk
22     save_generator_image(generated_img, f"../outputs/gen_img{epoch}.png")
23     images.append(generated_img)
24     epoch_loss_g = loss_g / bi # total generator loss for the epoch
25     epoch_loss_d = loss_d / bi # total discriminator loss for the epoch
26     losses_g.append(epoch_loss_g)
27     losses_d.append(epoch_loss_d)
28     
29     print(f"Epoch {epoch} of {epochs}")
30     print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

 

posted @ 2021-05-22 18:59  hi_mxd  阅读(597)  评论(0编辑  收藏  举报