Simple GAN的原理及Pytorch实现

follow this video: https://www.youtube.com/watch?v=OljTVUVzPpM

paper: https://papers.nips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf

结构

 

 

核心思想

判断器的任务是尽力将假的判断为假的,将真的判断为真的;生成器的任务是使生成的越真越好。两者交替迭代训练。

 

 max部分:D要尽可能的识别出真实数据和G生成的数据

 min部分:G要尽可能使生成的数据与真实数据相同,是G分辨不出来

全局的优化目标:

 

即生成的数据分布和真实数据的分布要相同。

效果就是G能将正态分布的随机噪声生成出与数据集分布相同的样本

核心代码

        ## D: 目标:真的判断为真,假的判断为假
        ## 训练Discriminator: max log(D(x)) + log(1-D(G(z)))
        disc_real = Disc(real).view(-1)  # 将真实图片放入到判别器中
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # 真的判断为真

        noise = torch.randn(batch_size, z_dim).to(device)   
        fake = Gen(noise)  # 将随机噪声放入到生成器中
        disc_fake = Disc(fake).view(-1)  # 识别器判断真假
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # 假的应该判断为假
        lossD = (lossD_real + lossD_fake) / 2  # loss包括判真损失和判假损失


        # G: 目标:生成的越真越好
        ## 训练生成器: min log(1-D(G(z))) <-> max log(D(G(z)))
        output = Disc(fake).view(-1)   # 生成的放入识别器
        lossG = criterion(output, torch.ones_like(output))  # 与“真的”的距离,越小越好

 完整代码见 https://github.com/growvv/GAN-Pytorch/blob/main/Simple-GAN/simple_gan.py

posted @ 2021-05-15 23:12  Rogn  阅读(323)  评论(0编辑  收藏  举报