Simple GAN的原理及Pytorch实现
follow this video: https://www.youtube.com/watch?v=OljTVUVzPpM
paper: https://papers.nips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf
结构
核心思想
判断器的任务是尽力将假的判断为假的,将真的判断为真的;生成器的任务是使生成的越真越好。两者交替迭代训练。
max部分:D要尽可能的识别出真实数据和G生成的数据
min部分:G要尽可能使生成的数据与真实数据相同,是G分辨不出来
全局的优化目标:
即生成的数据分布和真实数据的分布要相同。
效果就是G能将正态分布的随机噪声生成出与数据集分布相同的样本。
核心代码
## D: 目标:真的判断为真,假的判断为假 ## 训练Discriminator: max log(D(x)) + log(1-D(G(z))) disc_real = Disc(real).view(-1) # 将真实图片放入到判别器中 lossD_real = criterion(disc_real, torch.ones_like(disc_real)) # 真的判断为真 noise = torch.randn(batch_size, z_dim).to(device) fake = Gen(noise) # 将随机噪声放入到生成器中 disc_fake = Disc(fake).view(-1) # 识别器判断真假 lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) # 假的应该判断为假 lossD = (lossD_real + lossD_fake) / 2 # loss包括判真损失和判假损失 # G: 目标:生成的越真越好 ## 训练生成器: min log(1-D(G(z))) <-> max log(D(G(z))) output = Disc(fake).view(-1) # 生成的放入识别器 lossG = criterion(output, torch.ones_like(output)) # 与“真的”的距离,越小越好
完整代码见 https://github.com/growvv/GAN-Pytorch/blob/main/Simple-GAN/simple_gan.py
个性签名:时间会解决一切