GAN: Generating Handwritten Digits

The paper "Generative Adversarial Nets" is the origin of the whole GAN family. Here we implement a GAN in PyTorch and use it to generate handwritten digits.

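Abstract: We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than from G. The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1/2 everywhere. In the case where G and D are defined by multilayer perceptrons, the entire system can be trained with backpropagation.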
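For reference, the minimax game mentioned in the abstract is usually written with the value function below, in the paper's notation: p_data is the data distribution, p_z the noise prior, and p_g the distribution of G(z). For a fixed G the optimal discriminator has a closed form, and it reduces to 1/2 exactly when p_g = p_data, which is the "D equal to 1/2" statement above:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]

D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}, \qquad p_g = p_{\text{data}} \;\Rightarrow\; D^*_G(x) = \tfrac{1}{2}
```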

 

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt  # used for the loss plot and the generated samples

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
Packages to import

 

#########################
## SETTINGS
#########################

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # default GPU if present, else CPU

# Hyperparameters
random_seed = 123
generator_learning_rate = 0.001
discriminator_learning_rate = 0.001
num_epochs = 100
batch_size = 128
LATENT_DIM = 100
IMG_SHAPE = (1, 28, 28)
IMG_SIZE = 1
for x in IMG_SHAPE:
    IMG_SIZE *= x
Hyperparameter settings
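The loop above simply multiplies the entries of IMG_SHAPE, so a flattened 1 x 28 x 28 MNIST image has 784 values. A minimal equivalent using the standard library's math.prod (Python 3.8+) is sketched below; it is an alternative spelling, not what the original code uses.

```python
import math

IMG_SHAPE = (1, 28, 28)          # channels, height, width
IMG_SIZE = math.prod(IMG_SHAPE)  # 1 * 28 * 28 = 784 inputs per flattened image
print(IMG_SIZE)
```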

 

#########################
## MNIST DATASET
#########################

train_dataset = datasets.MNIST(root='../data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='../data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

# Output:

# Image batch dimensions: torch.Size([128, 1, 28, 28])
# Image label dimensions: torch.Size([128])
Loading the MNIST dataset
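transforms.ToTensor() maps MNIST pixels to [0, 1], while the generator ends in Tanh and therefore outputs values in [-1, 1]. That is why the training loop rescales real images with (features - 0.5) * 2 before passing them to the discriminator, so real and generated images live in the same range. A small NumPy sketch of the mapping and its inverse (the helper names to_tanh_range and from_tanh_range are hypothetical):

```python
import numpy as np

def to_tanh_range(x):
    """Map pixel values from [0, 1] (ToTensor output) to [-1, 1] (Tanh range)."""
    return (x - 0.5) * 2.0

def from_tanh_range(y):
    """Inverse mapping, e.g. to turn generator output back into [0, 1] pixels."""
    return y / 2.0 + 0.5

pixels = np.array([0.0, 0.25, 0.5, 1.0])
scaled = to_tanh_range(pixels)     # black -> -1, mid-gray -> 0, white -> 1
restored = from_tanh_range(scaled) # round trip recovers the original pixels
print(scaled, restored)
```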

 

##############################
## MODEL
##############################

class GAN(torch.nn.Module):
    
    def __init__(self):
        super(GAN, self).__init__()
        
        self.generator = nn.Sequential(
            nn.Linear(LATENT_DIM, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, IMG_SIZE),
            nn.Tanh()
        )
        
        self.discriminator = nn.Sequential(
            nn.Linear(IMG_SIZE, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    
    def generator_forward(self, z):
        img = self.generator(z)
        return img
    
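    def discriminator_forward(self, img):
        # Use self, not the global `model`, so the method works on any instance
        pred = self.discriminator(img)
        return pred.view(-1)


# Instantiate the model on the chosen device.
# The optimizer setup is not shown in the original post; Adam is assumed here.
torch.manual_seed(random_seed)
model = GAN().to(device)
optim_gener = torch.optim.Adam(model.generator.parameters(), lr=generator_learning_rate)
optim_discr = torch.optim.Adam(model.discriminator.parameters(), lr=discriminator_learning_rate)
GAN model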
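To get a feel for the model size, the number of trainable parameters can be worked out by hand: a Linear(in, out) layer has in * out weights plus out biases, while LeakyReLU, Dropout, Tanh and Sigmoid add none. The stdlib sketch below does that arithmetic (the helper linear_params is hypothetical); in PyTorch the same totals should come out of sum(p.numel() for p in model.generator.parameters()) and the corresponding call on model.discriminator.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer (bias=True)."""
    return in_features * out_features + out_features

LATENT_DIM, HIDDEN, IMG_SIZE = 100, 128, 784

# Generator: Linear(100, 128) -> LeakyReLU -> Dropout -> Linear(128, 784) -> Tanh
gen_params = linear_params(LATENT_DIM, HIDDEN) + linear_params(HIDDEN, IMG_SIZE)
# Discriminator: Linear(784, 128) -> LeakyReLU -> Dropout -> Linear(128, 1) -> Sigmoid
disc_params = linear_params(IMG_SIZE, HIDDEN) + linear_params(HIDDEN, 1)

print(gen_params, disc_params)  # 114064 100609
```

Both networks are tiny by modern standards, which partly explains why some generated digits come out blurry.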

 

start_time = time.time()

discr_costs = []
gener_costs = []

for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = (features - 0.5) * 2.
        features = features.view(-1, IMG_SIZE).to(device)
        targets = targets.to(device)
        
        # Adversarial ground truths
        valid = torch.ones(targets.size(0)).float().to(device)
        fake = torch.zeros(targets.size(0)).float().to(device)
        
        ### FORWARD AND BACK PROP
        
        # ---------------------
        # Train Generator
        # ---------------------
        
        # make new images
        z = torch.zeros((targets.size(0), LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
        
        # generate a batch of images
        generated_features = model.generator_forward(z)
        
        # Loss measures the generator's ability to fool the discriminator
        discr_pred = model.discriminator_forward(generated_features)       
        gener_loss = F.binary_cross_entropy(discr_pred, valid)
        
        optim_gener.zero_grad()
        gener_loss.backward()
        optim_gener.step()
        
        
        # ---------------------
        # Train Discriminator
        # ---------------------
        
        # Measure the discriminator's ability to tell real images from generated ones
        discr_pred_real = model.discriminator_forward(features.view(-1, IMG_SIZE))
        real_loss = F.binary_cross_entropy(discr_pred_real, valid)        
        discr_pred_fake = model.discriminator_forward(generated_features.detach())
        fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)        
        discr_loss = 0.5 * (real_loss + fake_loss)
        
        optim_discr.zero_grad()
        discr_loss.backward()
        optim_discr.step()
        
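        # Store Python floats, not tensors, so the lists can be plotted and don't keep graphs alive
        discr_costs.append(discr_loss.item())
        gener_costs.append(gener_loss.item())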
        
        ### LOGGING
        if not batch_idx % 100:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'
                 %(epoch+1, num_epochs, batch_idx, len(train_loader), gener_loss, discr_loss))
        
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
        
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))   
Network training
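Two details of the loop are worth spelling out. The generator is trained against the valid labels, so its BCE loss is -log D(G(z)); this is the "non-saturating" variant suggested in the GAN paper instead of minimizing log(1 - D(G(z))) directly. The discriminator loss averages the real and fake BCE terms. A NumPy sketch of both losses (helper names are hypothetical, and bce only approximates F.binary_cross_entropy) shows that when D outputs 1/2 everywhere, the equilibrium from the abstract, both losses equal log 2, about 0.693, a handy reference value when reading the training log:

```python
import numpy as np

def bce(p, target, eps=1e-12):
    """Mean binary cross-entropy over a batch (approximates F.binary_cross_entropy)."""
    p = np.clip(p, eps, 1.0 - eps)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))

def generator_loss(d_fake):
    # Non-saturating loss: label generated samples as "valid" (1) -> -log D(G(z))
    return bce(d_fake, np.ones_like(d_fake))

def discriminator_loss(d_real, d_fake):
    # Average of the real-vs-1 and fake-vs-0 terms, as in discr_loss above
    real_loss = bce(d_real, np.ones_like(d_real))
    fake_loss = bce(d_fake, np.zeros_like(d_fake))
    return 0.5 * (real_loss + fake_loss)

half = np.full(4, 0.5)  # D outputs 1/2 for every sample
print(generator_loss(half), discriminator_loss(half, half))
```

If the discriminator loss drops toward 0 while the generator loss keeps climbing, D is winning easily and the generator is receiving little useful signal.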

 

Plot how the generator loss and the discriminator loss evolve during training:

plt.plot(range(len(gener_costs)), gener_costs, label='generator loss')
plt.plot(range(len(discr_costs)), discr_costs, label='discriminator loss')
plt.legend()
plt.savefig('./loss.jpg')
plt.show()
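With about 469 batches per epoch (60,000 training images / 128, rounded up) and 100 epochs, the raw curves contain tens of thousands of noisy points. One optional way to make the trend easier to see is a simple moving average before plotting; the window size is an arbitrary assumption and the helper moving_average is hypothetical:

```python
import numpy as np

def moving_average(values, window=100):
    """Mean over each run of `window` consecutive losses ('valid' mode, no padding)."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

# Tiny example: window 2 averages neighbouring values
print(moving_average([1.0, 3.0, 5.0, 7.0], window=2))  # [2. 4. 6.]
```

Assuming gener_costs and discr_costs hold Python floats, plt.plot(moving_average(gener_costs), label='generator loss (smoothed)') would then replace the raw curve.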

Use the trained generator to produce some synthetic handwritten digits:

#########################
## VISUALIZATION
#########################

model.eval()  # disables dropout in the generator
# Make new images from uniform noise in [-1, 1]; no gradients are needed here
with torch.no_grad():
    z = torch.zeros((5, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
    generated_features = model.generator_forward(z)
imgs = generated_features.view(-1, 28, 28).cpu()  # move to CPU for matplotlib

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 2.5))

for i, ax in enumerate(axes):
    ax.imshow(imgs[i].numpy(), cmap='binary')
plt.show()
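The noise z used at sampling time should come from the same prior as during training, here Uniform(-1, 1) with LATENT_DIM = 100; torch.zeros(...).uniform_(-1.0, 1.0) is one way to draw it. The NumPy sketch below does the same with a seeded generator, only to make the shape and range explicit (the seed value is an arbitrary assumption):

```python
import numpy as np

LATENT_DIM = 100
rng = np.random.default_rng(123)  # any seed works; fixed here for reproducibility

# 5 latent vectors drawn i.i.d. from Uniform(-1, 1), one per image to generate
z = rng.uniform(-1.0, 1.0, size=(5, LATENT_DIM))
print(z.shape, z.min(), z.max())
```

Each call with a different seed (or no seed) gives a different batch of digits, which is why rerunning the cell produces new samples.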

Running the generator a few more times:

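As you can see, some of the generated digits are quite sharp, while others are blurry and hard to recognize. Still, the results are already pretty exciting!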

The GAN can be improved further to generate higher-quality images.

 

 

References

  [1] deeplearning-models (GitHub)

  [2] Paper: "Generative Adversarial Nets"

 

 
posted @ 2019-08-09 14:27  虔诚的树