07_WGAN

  GAN现在是热门的研究课题,一般有两种类型的GAN研究方向:一种是在各种各样的问题中应用GAN,另一种是试图稳定GAN的训练。 稳定GAN的训练过程是一个非常重要的事情

 

  原始GAN训练过程中经常遇到的问题:

  • 模式崩溃, 生成器生成非常窄的分布,仅覆盖数据分布中的单一模式。 模式崩溃的含义是生成器只能生成非常相似的样本(例如, MNIST中的单个数字),即生成的样本不是多样的。
  • 没有指标可以告诉我们收敛情况。 生成器和判别器的loss并没有告诉我们任何收敛相关信息。当然,我们可以通过不时地查看生成器生成的数据来监控训练进度。但是,这是一个手动过程。因此, 我们需要有一个可解释的指标可以告诉我们有关训练的进度。

  原始GAN形式的问题,一句话概括:判别器越好,生成器梯度消失越严重。 GAN网络训练的重点在于均衡生成器与判别器,若判别器太强, loss没有再下降,生成器学习不到东西,生成图像的质量便不会再有提升。 在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化真实分布与生成分布之间的JS散度。

  关键点就在于如何评价生成图片和真实图片之间的距离 :

 

 

JS散度存在的问题

  如果希望两个分布之间越接近它们的JS散度越小,我们通过优化JS散度就能将生成分布拉向真实分布,最终以假乱真。这个希望在两个分布有所重叠的时候是成立的,但是如果两个分布完全没有重叠的部分,或者它们重叠的部分可忽略,那它们的JS散度就一直是 log2 。

原始GAN存在的问题

  在原始GAN的(近似)最优判别器下,生成器loss面临梯度消失问题。也面临优化目标荒谬、梯度不稳定、对多样性与准确性惩罚不平衡导致mode collapse问题。
  原始GAN问题的根源可以归结为两点:
    一是等价优化的距离衡量(JS散度)不合理,
    二是生成器随机初始化后的生成分布很难与真实分布有不可
忽略的重叠。
  Wasserstein GAN(WGAN)就是希望解决上述两个问题。

  注:WGAN(Wasserstein GAN)论文:https://arxiv.org/abs/1701.07875

解决原始GAN问题的方法:

  解决问题的关键在于使用:Wasserstein距离;
  衡量两个分布之间的距离Wasserstein距离 优越性在于:即使两个分布没有任何重叠,也可以反应他们之间的距离.
Wasserstein距离
  
P和Q为两个分布: P分布为一堆土, Q分布为要移到的目标,那么要移动P达到Q,哪种距离更小呢?

 

 使用Wasserstein距离, 无论两个分布多远,都有梯度,都是可以更新的.

 WGAN的设计

原始的生成对抗网络,所要优化的目标函数为:

 此目标函数可以分为两部分来看:
①固定生成器 G,优化判别器 D, 则上式可以写成如下形式:

 可以转化为最小化形式:

 ②固定判别器 D,优化生成器 G,舍去前面的常数,相当于最小化:

 也相当于最小化:

 我们要构建一个判别器 D,使得 D 的参数不超过某个固定的常数,最后一层是非线性层,并且使下面式子最大化:

 这是一种数学的近似,同要求梯度变化的不要太猛烈。

   那么怎么梯度更新呢?因为D有了限制,无法直接利用SGD。这里引入一种方法: Weight clipping就是强制令权重w 限制在c ~ -c之间。在参数更新后,如果w>c,则令w=c, 如果w<-c,则令w=-c。

WGAN的实现
  WGAN与原始GAN第一种形式相比,只改了四点:

    一、判别器最后一层去掉sigmoid

    二、生成器和判别器的loss不取log

    三、每次更新判别器的参数之后把它们的值截断到不超过一个固定常数c

    四、不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,但是RMSProp适合梯度不稳定的情况。

  WGAN本作引入了Wasserstein距离,由于它相对KL散度与JS散度具有优越的平滑特性,理论上可以解决梯度消失问题。接着通过数学变换将Wasserstein距离写成可求解的形式,利用一个参数数值范围受限的判别器神经网络来较大化这个形式,就可以近似Wasserstein距离。
  WGAN既解决了训练不稳定的问题,也提供了一个可靠的训练进程指标,而且该指标确实与生成样本的质量高度相关。


WGAN的代码实现(pytorch)

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torch.optim as optim
  5 import numpy as np
  6 import matplotlib.pyplot as plt
  7 import torchvision
  8 from torchvision import transforms
  9 
 10 
 11 
 12 transform = transforms.Compose([
 13     transforms.ToTensor(),
 14     transforms.Normalize(mean=0.5, std=0.5)
 15 ])
 16 
 17 train_ds = torchvision.datasets.MNIST('data/',
 18                                       train=True,
 19                                       transform=transform,
 20                                       download=True)
 21 
 22 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True)
 23 
 24 #定义生成器
 25 
 26 # 使用长度为 100 的noise作为输入,也可以使用torch.randn(batchsize, 100, 1, 1)
 27 # 生成 1*28*28
 28 class Generator(nn.Module):
 29     def __init__(self):
 30         super(Generator, self).__init__()
 31         self.linear1 = nn.Linear(100, 256*7*7)
 32         self.bn1 = nn.BatchNorm1d(256*7*7)
 33         self.deconv1 = nn.ConvTranspose2d(256, 128,
 34                                           kernel_size=(3, 3),
 35                                           stride=1,
 36                                           padding=1)     # (128, 7, 7)
 37         self.bn2 = nn.BatchNorm2d(128)
 38         
 39         self.deconv2 = nn.ConvTranspose2d(128, 64,
 40                                           kernel_size=(4, 4),
 41                                           stride=2,
 42                                           padding=1)     # (64, 14, 14)
 43         self.bn3 = nn.BatchNorm2d(64)
 44         
 45         self.deconv3 = nn.ConvTranspose2d(64, 1,
 46                                           kernel_size=(4, 4),
 47                                           stride=2,
 48                                           padding=1)     # (1, 28, 28)
 49     
 50     def forward(self, x):
 51         x = F.relu(self.linear1(x))
 52         x = self.bn1(x)
 53         x = x.view(-1, 256, 7, 7)
 54         x = F.relu(self.deconv1(x))
 55         x = self.bn2(x)
 56         x = F.relu(self.deconv2(x))
 57         x = self.bn3(x)
 58         x = torch.tanh(self.deconv3(x))
 59         return x
 60     
 61 # input: 1, 28, 28的图片
 62 class Discriminator(nn.Module):
 63     def __init__(self):
 64         super(Discriminator, self).__init__()
 65         self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2)
 66         self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
 67         self.bn = nn.BatchNorm2d(128)
 68         self.fc = nn.Linear(128*6*6, 1)
 69     def forward(self, x):
 70         x = F.dropout2d(F.leaky_relu(self.conv1(x)))
 71         x = F.dropout2d(F.leaky_relu(self.conv2(x)))    # (batch, 128, 6, 6)
 72         x = self.bn(x)
 73         x = x.view(-1, 128*6*6)  # (batch, 128, 6, 6)-->  (batch, 128*6*6)
 74         x = self.fc(x)
 75         return x
 76 
 77 device = "cuda" if torch.cuda.is_available() else "cpu"
 78 #device = "cpu"
 79 gen = Generator().to(device)
 80 dis = Discriminator().to(device)
 81 
 82 #loss_fn = torch.nn.BCELoss()  # 损失函数
 83 d_optimizer = torch.optim.RMSprop(dis.parameters(), lr=2e-4)
 84 g_optimizer = torch.optim.RMSprop(gen.parameters(), lr=2e-4)
 85 
 86 def generate_and_save_images(model, epoch, test_input):
 87     predictions = np.squeeze(model(test_input).cpu().numpy())
 88     fig = plt.figure(figsize=(4, 4))
 89     for i in range(predictions.shape[0]):
 90         plt.subplot(4, 4, i+1)
 91         plt.imshow((predictions[i] + 1)/2, cmap='gray')
 92         plt.axis('off')
 93     plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
 94     plt.show()
 95     
 96 test_input = torch.randn(16, 100, device=device)
 97 
 98 D_loss = []
 99 G_loss = []
100 
101 #开始训练
102 for epoch in range(30):
103     D_epoch_loss=0
104     G_epoch_loss=0
105     count = len(train_dl.dataset)
106     for step, (img, _) in enumerate(train_dl):
107         img = img.to(device)
108         size=img.shape[0]
109         random_seed = torch.randn(size, 100, device=device)
110         
111         
112         for p in dis.parameters():
113             p.data.clamp_(-0.01, 0.01)
114         d_optimizer.zero_grad()
115         real_output = dis(img).mean()
116         
117         generated_img = gen(random_seed)
118         fake_output = dis(generated_img.detach()).mean()
119         
120         disc_loss = -real_output + fake_output
121         disc_loss.backward()
122         d_optimizer.step()
123         
124         g_optimizer.zero_grad()
125         fake_output = dis(generated_img)
126         gen_loss = -torch.mean(fake_output)
127         gen_loss.backward()
128         g_optimizer.step()
129         
130         with torch.no_grad():
131             D_epoch_loss += disc_loss.item()
132             G_epoch_loss += gen_loss.item()
133     with torch.no_grad():        
134         D_epoch_loss /= count
135         G_epoch_loss /= count
136         D_loss.append(D_epoch_loss)
137         G_loss.append(G_epoch_loss)
138         generate_and_save_images(gen, epoch, test_input)
139     print('Epoch: ', epoch)
140 
141 plt.plot(D_loss, label='D_loss')
142 plt.plot(G_loss, label='G_loss')
143 plt.legend()
View Code

 



posted @ 2022-12-31 18:12  赵家小伙儿  阅读(175)  评论(0编辑  收藏  举报