PyTorch实现简单的生成对抗网络GAN
生成对抗网络是一个关于数据的生成模型:即给定训练数据,GANs能够估计数据的概率分布,基于这个概率分布产生数据样本(这些样本可能并没有出现在训练集中)。
GAN中,两个神经网络互相竞争。给定训练集X,假设是几千张猫的图片。将一个随机向量输入给生成器G(x),让G(x)生成和训练集类似的图片。判别器D(x)是一个二分类器,其试图区分真实的猫图片和生成器生成的假猫图片。总的来说,生成器的目的是学习训练数据的分布,生成尽可能真实的猫图片,以确保判别器无法区分。判别器需要不断地学习生成器的“造假图片”,以防止自己被欺骗。
判别器与生成器不断“斗智斗勇”的过程中,生成器或多或少地学习到了训练数据的真实分布,已经能生成一些以假乱真的图片了;而判别器最终已经无法判断猫的图片是真实的,还是来自于生成器。从某种意义上来说,生成器和判别器都希望对方“失败”。
从另外一个角度来说,判别器实际上是在指导生成器,告诉生成器: 真的猫图片到底什么样?模型训练的最终结果是生成器能够学习到数据的分布,最终可以生成近似真的猫图片。GANs的训练方法类似于博弈论中的MinMax算法,生成器和判别器最终达到了纳什均衡。(摘自https://zhuanlan.zhihu.com/p/74663048)
生成对抗网络(Generative Adversarial Network, GAN)包括生成网络和对抗网络两部分。生成网络像自动编码器的解码器,能够生成数据,比如生成一张图片。对抗网络用来判断数据的真假,比如是真图片还是假图片,真图片是拍摄得到的,假图片是生成网络生成的。
以下程序主要来自廖星宇的《深度学习之PyTorch》的第六章,本文对原代码进行了改进:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | import torch from torch import nn import torchvision.transforms as tfs from torch.utils.data import DataLoader from torchvision.datasets import MNIST import numpy as np import matplotlib.pyplot as plt def preprocess_img(x): x = tfs.ToTensor()(x) # x (0., 1.) return (x - 0.5 ) / 0.5 # x (-1., 1.) def deprocess_img(x): # x (-1., 1.) return (x + 1.0 ) / 2.0 # x (0., 1.) def discriminator(): net = nn.Sequential( nn.Linear( 784 , 256 ), nn.LeakyReLU( 0.2 ), nn.Linear( 256 , 256 ), nn.LeakyReLU( 0.2 ), nn.Linear( 256 , 1 ), ) return net def generator(noise_dim): net = nn.Sequential( nn.Linear(noise_dim, 1024 ), nn.ReLU( True ), nn.Linear( 1024 , 1024 ), nn.ReLU( True ), nn.Linear( 1024 , 784 ), nn.Tanh(), ) return net def discriminator_loss(logits_real, logits_fake): # 判别器的loss size = logits_real.shape[ 0 ] true_labels = torch.ones(size, 1 ). float () false_labels = torch.zeros(size, 1 ). float () bce_loss = nn.BCEWithLogitsLoss() loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels) return loss def generator_loss(logits_fake): # 生成器的 loss size = logits_fake.shape[ 0 ] true_labels = torch.ones(size, 1 ). float () bce_loss = nn.BCEWithLogitsLoss() loss = bce_loss(logits_fake, true_labels) # 假图与真图的误差。训练的目的是减小误差,即让假图接近真图。 return loss # 使用 adam 来进行训练,beta1 是 0.5, beta2 是 0.999 def get_optimizer(net, LearningRate): optimizer = torch.optim.Adam(net.parameters(), lr = LearningRate, betas = ( 0.5 , 0.999 )) return optimizer def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, noise_size, num_epochs, num_img): f, a = plt.subplots(num_img, num_img, figsize = (num_img, num_img)) plt.ion() # Turn the interactive mode on, continuously plot for epoch in range (num_epochs): for iteration, (x, _) in enumerate (train_data): bs = x.shape[ 0 ] # 训练判别网络 real_data = x.view(bs, - 1 ) # 真实数据 logits_real = D_net(real_data) # 判别网络得分 rand_noise = (torch.rand(bs, noise_size) - 0.5 ) / 0.5 # -1 ~ 1 的均匀分布 fake_images = G_net(rand_noise) # 生成的假的数据 logits_fake = D_net(fake_images) # 判别网络得分 d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step() # 优化判别网络 # 训练生成网络 rand_noise = (torch.rand(bs, noise_size) - 0.5 ) / 0.5 # -1 ~ 1 的均匀分布 fake_images = G_net(rand_noise) # 生成的假的数据 gen_logits_fake = D_net(fake_images) g_error = generator_loss(gen_logits_fake) # 生成网络的 loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step() # 优化生成网络 if iteration % 20 = = 0 : print ( 'Epoch: {:2d} | Iter: {:<4d} | D: {:.4f} | G:{:.4f}' . format (epoch, iteration, d_total_error.data.numpy(), g_error.data.numpy())) imgs_numpy = deprocess_img(fake_images.data.cpu().numpy()) for i in range (num_img * * 2 ): a[i / / num_img][i % num_img].imshow(np.reshape(imgs_numpy[i], ( 28 , 28 )), cmap = 'gray' ) a[i / / num_img][i % num_img].set_xticks(()) a[i / / num_img][i % num_img].set_yticks(()) plt.suptitle( 'epoch: {} iteration: {}' . format (epoch, iteration)) plt.pause( 0.01 ) plt.ioff() plt.show() if __name__ = = '__main__' : EPOCH = 5 BATCH_SIZE = 128 LR = 5e - 4 NOISE_DIM = 96 NUM_IMAGE = 4 # for showing images when training train_set = MNIST(root = '/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/' , train = True , download = False , transform = preprocess_img) train_data = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True ) D = discriminator() G = generator(NOISE_DIM) D_optim = get_optimizer(D, LR) G_optim = get_optimizer(G, LR) train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, NOISE_DIM, EPOCH, NUM_IMAGE) |
效果:
程序的理解:
训练Discriminator:
训练Generatord:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通