GAN——生成手写数字
《Generative Adversarial Nets》是 GAN 系列的鼻祖。本文用 PyTorch 实现 GAN，并将其用于生成 MNIST 手写数字。
摘要：我们提出了一个新的框架，通过对抗过程来估计生成模型。在该框架中，我们同时训练两个模型：一个是生成模型 G，用于捕获数据分布；另一个是判别模型 D，用于估计样本来自训练数据而非生成模型 G 的概率。G 的训练目标是最大化 D 犯错的概率。这个框架对应于一个极小极大的二人博弈。在任意函数 G 和 D 构成的空间中，存在唯一解：G 恢复出训练数据的分布，而 D 处处等于 1/2。当 G 和 D 都由多层感知机定义时，整个系统可以通过反向传播进行训练。
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
#########################
## SETTINGS
#########################

# Device: prefer GPU index 2 (the author's setup), but fall back to the
# default GPU when fewer than three devices exist, and to the CPU when CUDA
# is unavailable -- otherwise moving tensors fails with an invalid ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda")
else:
    device = torch.device("cpu")

# Hyperparameters
random_seed = 123
generator_learning_rate = 0.001
discriminator_learning_rate = 0.001
num_epochs = 100
batch_size = 128

LATENT_DIM = 100         # length of the noise vector z fed to the generator
IMG_SHAPE = (1, 28, 28)  # (channels, height, width) of one image

# Number of values in one flattened image (product of IMG_SHAPE).
IMG_SIZE = 1
for _dim in IMG_SHAPE:
    IMG_SIZE *= _dim
del _dim  # keep the loop variable out of the module namespace
#########################
## MNIST DATASET
#########################

# ToTensor() turns each PIL image into a float tensor; the training loop
# rescales these values before feeding the networks.
to_tensor = transforms.ToTensor()

train_dataset = datasets.MNIST(
    root='../data',
    train=True,
    transform=to_tensor,
    download=True,
)
test_dataset = datasets.MNIST(
    root='../data',
    train=False,
    transform=to_tensor,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: look at the shapes of the first training batch only.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
# Expected output:
# Image batch dimensions: torch.Size([128, 1, 28, 28])
# Image label dimensions: torch.Size([128])
##############################
## MODEL
##############################

class GAN(torch.nn.Module):
    """Fully connected GAN made of an MLP generator and an MLP discriminator.

    The generator maps latent vectors (last dimension ``latent_dim``) to
    flattened images of ``img_size`` values squashed into [-1, 1] by Tanh.
    The discriminator maps flattened images to a Sigmoid probability that
    the sample is real.

    Args:
        latent_dim: Length of the noise vector. Defaults to the module-level
            ``LATENT_DIM``.
        img_size: Number of values in a flattened image. Defaults to the
            module-level ``IMG_SIZE``.
    """

    def __init__(self, latent_dim=None, img_size=None):
        super().__init__()
        # Resolve defaults at call time so the class can be defined before
        # (or reused independently of) the module-level settings.
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if img_size is None:
            img_size = IMG_SIZE

        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, img_size),
            nn.Tanh(),
        )

        self.discriminator = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def generator_forward(self, z):
        """Decode latent vectors ``z`` into flattened images in [-1, 1]."""
        img = self.generator(z)
        return img

    def discriminator_forward(self, img):
        """Return a 1-D tensor with one 'real' probability per input sample."""
        # Use this instance's own discriminator -- never a global model --
        # so every GAN instance scores samples with its own weights.
        pred = self.discriminator(img)
        return pred.view(-1)
#########################
## TRAINING
#########################

# Build the model and its optimizers from the SETTINGS hyperparameters;
# without these lines the loop below raises NameError on `model`,
# `optim_gener` and `optim_discr`.
# NOTE(review): the post never shows the optimizer; Adam is an assumption.
torch.manual_seed(random_seed)
model = GAN().to(device)
optim_gener = torch.optim.Adam(model.generator.parameters(),
                               lr=generator_learning_rate)
optim_discr = torch.optim.Adam(model.discriminator.parameters(),
                               lr=discriminator_learning_rate)

start_time = time.time()
discr_costs = []
gener_costs = []

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (features, _) in enumerate(train_loader):

        # Rescale pixels from [0, 1] to [-1, 1] to match the generator's Tanh
        # output range, then flatten each image into a vector.
        features = ((features - 0.5) * 2.).view(-1, IMG_SIZE).to(device)
        n_samples = features.size(0)

        # Adversarial ground truths
        valid = torch.ones(n_samples, device=device)
        fake = torch.zeros(n_samples, device=device)

        # ---------------------
        # Train Generator
        # ---------------------
        # Sample latent vectors uniformly from [-1, 1).
        z = torch.zeros((n_samples, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
        generated_features = model.generator_forward(z)

        # The generator is rewarded when the discriminator calls its samples real.
        discr_pred = model.discriminator_forward(generated_features)
        gener_loss = F.binary_cross_entropy(discr_pred, valid)

        optim_gener.zero_grad()
        gener_loss.backward()
        optim_gener.step()

        # ---------------------
        # Train Discriminator
        # ---------------------
        # Measure discriminator's ability to classify real from generated samples.
        discr_pred_real = model.discriminator_forward(features)
        real_loss = F.binary_cross_entropy(discr_pred_real, valid)

        # detach(): the discriminator update must not backprop into the generator.
        discr_pred_fake = model.discriminator_forward(generated_features.detach())
        fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)

        discr_loss = 0.5 * (real_loss + fake_loss)

        optim_discr.zero_grad()
        discr_loss.backward()
        optim_discr.step()

        # Store plain Python floats: the loss tensors require grad (and may
        # live on the GPU), which keeps them off-host and makes plotting fail.
        discr_costs.append(discr_loss.item())
        gener_costs.append(gener_loss.item())

        ### LOGGING
        if not batch_idx % 100:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'
                  % (epoch + 1, num_epochs, batch_idx, len(train_loader),
                     gener_costs[-1], discr_costs[-1]))

    print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))

print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))
画出 generator loss 和 discriminator loss 的变化图:
# Plot the per-batch losses of both networks.
# float() accepts plain numbers as well as 0-dim loss tensors (including ones
# on the GPU or still attached to the autograd graph), which matplotlib
# cannot convert on its own.
gener_curve = [float(c) for c in gener_costs]
discr_curve = [float(c) for c in discr_costs]

plt.plot(range(len(gener_curve)), gener_curve, label='generator loss')
plt.plot(range(len(discr_curve)), discr_curve, label='discriminator loss')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend()
plt.savefig('./loss.jpg')
plt.show()
利用以上训练的 Generator 生成一些仿手写数字图片:
#########################
## VISUALIZATION
#########################

model.eval()  # turn off Dropout so each image depends only on its z

# Sample five latent vectors and decode them; no gradients are needed here.
with torch.no_grad():
    z = torch.zeros((5, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
    generated_features = model.generator_forward(z)

# Reshape to (N, height, width) and move to the CPU: .numpy() fails on
# CUDA tensors.
imgs = generated_features.view(-1, *IMG_SHAPE[1:]).cpu()

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 2.5))
for ax, img in zip(axes, imgs):
    ax.imshow(img.numpy(), cmap='binary')
plt.show()
再生成几次:
可以发现，生成的数字图片中有些很清晰，有些则比较模糊、不易辨认；不过对于这样一个简单的全连接 GAN 来说，这个结果已经让人很兴奋了~~
后续可以对 GAN 进行改进（例如改用卷积结构的 DCGAN），从而生成质量更高的图片。
Reference