Generative Models

AE

VAE

GAN

Application Goals

Generative tasks (generation, reconstruction, super-resolution, style transfer, inpainting, upsampling, etc.)

Core Idea

A round-by-round adversarial game between the generator G and the discriminator D

  • Generator G: a generative network that produces images from its input, aiming to generate data that D cannot tell apart from real data
  • Discriminator D: a binary classification network that takes the generator's images as negative samples and real images as positive samples, aiming to separate the distribution of G's output from the real data distribution as well as possible
  • Training D: with G fixed, use images produced by G as negative samples, together with real images as positive samples, to train D
  • Training G: with D fixed, train G so that D's score on G's images gets as close as possible to the score for positive samples
  • Training of G and D alternates; the adversarial process makes G's images ever more realistic and D ever better at telling real from fake


Algorithm Principles

The ingenuity of GANs: how the loss function for the generative model is handled

G (the generator network): takes a random noise vector \(z\) and generates an image from it, denoted \(G(z)\)

The randomness of the input noise gives the generated images their diversity

D (the discriminator network): takes an input \(x\), where \(x\) is an image; the output \(D(x)\) is the probability that \(x\) is a real image. A value of 1 means it is certainly real; a value of 0 means it cannot be real

Problem Analysis

How do we analyze the objective function?

$$\min_{G}\max_{D}V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1-D(G(z)))]$$

The logarithm is monotonically increasing on its domain, so taking logs does not change the relative ordering of values; applying \(\log\) amplifies the loss, which makes it easier to compute and optimize
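As a quick numeric illustration of that amplification (a sketch added here, not from the original post): the linear error \(1-D(x)\) grows slowly as \(D(x)\) falls, while \(-\log D(x)\) blows up:

import math

# -log D(x) explodes as D(x) -> 0, while 1 - D(x) only grows linearly,
# so the log loss punishes confident mistakes far more heavily
for d_x in [0.99, 0.9, 0.5, 0.1, 0.01]:
    print(f"D(x)={d_x:5.2f} | 1-D(x)={1 - d_x:5.2f} | -log D(x)={-math.log(d_x):6.2f}")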

  • First half of the formula: $$E_{x\sim p_{data}(x)}[\log D(x)]$$

    • \(D(x)\) is the discriminator's verdict on a real image; after taking the log, the aim is to drive \(\log D(x)\) toward 0, which means driving \(D(x)\) toward 1, with the log amplifying the loss along the way
    • \(E_{x\sim p_{data}(x)}\) denotes the expectation with \(x\) drawn from \(p_{data}\)
      • \(x\) is a real data point (an image)
      • \(p_{data}\) is the distribution of the real data
    • Putting the first half together
      • Meaning: the probability that the discriminator recognizes real data as real.
      • Optimization goal: make this probability as large as possible
  • Second half of the formula: $$E_{z\sim p_{z}(z)}[\log(1-D(G(z)))]$$

    • \(E_{z\sim p_{z}(z)}\) denotes the expectation with \(z\) drawn from \(p_{z}\)
      • \(z\) is random noise
      • \(p_{z}(z)\) is the distribution the noise is sampled from
    • For the discriminator D, when the input is generated data (\(D(G(z))\)), the goal is to label it fake (i.e. \(D(G(z))=0\)), which means making \(\log(1-D(G(z)))\) as large as possible
    • For the generator G, the goal is to have its output judged real (i.e. \(D(G(z))=1\)), which means making \(\log(1-D(G(z)))\) as small as possible
    • In short, D and G optimize in opposite directions
  • Summary (a toy numeric check of \(V(D,G)\) follows this list)

    • The discriminator D maximizes \(\log D(x)\) and \(\log(1-D(G(z)))\), thereby maximizing \(V(D,G)\)
    • The generator G minimizes \(\log(1-D(G(z)))\), thereby minimizing \(V(D,G)\)
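To make the two expectations concrete, here is a toy Monte Carlo estimate of \(V(D,G)\) in one dimension. The Gaussian "real" data, the stand-in generator distribution, and the fixed sigmoid discriminator are all hypothetical choices, purely for illustration:

import numpy as np

rng = np.random.default_rng(0)

# toy 1-D "real" data from N(2, 1) and "generated" data from N(0, 1)
x_real = rng.normal(2.0, 1.0, size=10_000)
x_fake = rng.normal(0.0, 1.0, size=10_000)

# a hypothetical fixed discriminator: sigmoid of a linear score
def D(x):
    return 1.0 / (1.0 + np.exp(-(x - 1.0)))

# V(D,G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]
V = np.mean(np.log(D(x_real))) + np.mean(np.log(1.0 - D(x_fake)))
print(f"Monte Carlo estimate of V(D,G): {V:.4f}")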

Update D's parameters first; the trained D then guides the direction of G's updates.
Parsing the formula: $$\min_{G}\max_{D}V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1-D(G(z)))]$$

  • First solve \(\max_{D}V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1-D(G(z)))]\): G is held fixed and D is trained to separate positive from negative samples, hence the \(\max_{D}\) (a closed-form view of this inner step follows below)
  • Then solve the overall problem: D is held fixed and the generator G is adjusted so that the discriminator slips up, i.e. can no longer separate positive from negative samples (making the generated images more realistic)
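The inner maximization actually has a closed-form solution, a standard result from the original GAN paper, added here as a worked step. Writing \(p_{g}\) for the distribution of \(G(z)\), the integrand \(p_{data}(x)\log D(x)+p_{g}(x)\log(1-D(x))\) is maximized pointwise at

$$D^{*}(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}$$

Substituting \(D^{*}\) back in gives \(V(D^{*},G)=-\log 4+2\,\mathrm{JSD}(p_{data}\,\|\,p_{g})\), so the outer minimization over G drives \(p_{g}\) toward \(p_{data}\); at the optimum \(D^{*}(x)=1/2\) everywhere, and the discriminator can do no better than guessing.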

Every time a new generator is trained, a discriminator must be trained against it: the discriminator should push its value on real images as high as possible and its value on generated images as low as possible, steadily gaining discriminating power. The problem is therefore dynamic, unlike the traditional setting where the loss function stays fixed.

How are images generated?

How should G and D be set up?

How is training carried out?

The training procedure, in pseudocode:
for iteration in range(num_iterations):
    for each batch of real images:
        combined batch = real batch + generated batch  # (doubles the batch)
        for k in range(discriminator_steps):
            step 1: update the discriminator D on the combined batch
        step 2: update the generator G with D held fixed

Loss Functions

Generator loss (can it produce images close to real ones that the discriminator judges to be real?): computed from the discriminator's output

Discriminator loss (can it correctly tell generated images from real ones?): the discriminator outputs a probability, and the loss is computed with cross-entropy
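Concretely, both losses in the code below are binary cross-entropy with labels 1 for real and 0 for fake. The discriminator minimizes

$$L_{D}=-\left[\log D(x)+\log(1-D(G(z)))\right]$$

while the generator, as in the code that follows, minimizes the common non-saturating variant

$$L_{G}=-\log D(G(z))$$

rather than \(\log(1-D(G(z)))\): scoring fake images against real labels ("flipped labels") gives the generator much stronger gradients early in training, when D rejects its samples easily.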

Code Implementation

Load the dataset (and visualize it)

import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms

# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 64

# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# get the training datasets
train_data = datasets.MNIST(root='data', train=True,
                                   download=True, transform=transform)

# prepare data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           num_workers=num_workers)
# visualize one batch
dataiter = iter(train_loader)
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch
images = images.numpy()

# get one image from the batch
img = np.squeeze(images[0])

fig = plt.figure(figsize = (3,3)) 
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')


import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()
        
        # define hidden linear layers
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        
        # final fully-connected layer
        self.fc4 = nn.Linear(hidden_dim, output_size)
        
        # dropout layer 
        self.dropout = nn.Dropout(0.3)
        
        
    def forward(self, x):
        # flatten image
        x = x.view(-1, 28*28)
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # final layer
        out = self.fc4(x)

        return out

class Generator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()
        
        # define hidden linear layers
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        
        # final fully-connected layer
        self.fc4 = nn.Linear(hidden_dim*4, output_size)
        
        # dropout layer 
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # final layer with tanh applied: outputs lie in [-1, 1], matching the
        # rescaled real images; torch.tanh replaces the deprecated F.tanh
        out = torch.tanh(self.fc4(x))

        return out

# Discriminator hyperparams

# Size of input image to discriminator (28*28)
input_size = 784
# Size of discriminator output (real or fake)
d_output_size = 1
# Size of last hidden layer in the discriminator
d_hidden_size = 32

# Generator hyperparams

# Size of latent vector to give to generator
z_size = 100
# Size of generator output (a flattened 28*28 generated image)
g_output_size = 784
# Size of first hidden layer in the generator
g_hidden_size = 32

# instantiate discriminator and generator
D = Discriminator(input_size, d_hidden_size, d_output_size)
G = Generator(z_size, g_hidden_size, g_output_size)

# check that they are as you expect
print(D)
print()
print(G)

Discriminator(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=1, bias=True)
  (dropout): Dropout(p=0.3)
)

Generator(
  (fc1): Linear(in_features=100, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=784, bias=True)
  (dropout): Dropout(p=0.3)
)
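As a quick sanity check (added, not in the original post), one forward pass confirms the two networks line up: G maps latent vectors to flattened images, which D then scores.

# smoke test: one forward pass through each network
z_test = torch.randn(4, z_size)
img_test = G(z_test)
print(img_test.shape)     # torch.Size([4, 784])
print(D(img_test).shape)  # torch.Size([4, 1])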

# Calculate losses
def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # label smoothing
    if smooth:
        # smooth, real labels = 0.9
        labels = torch.ones(batch_size)*0.9
    else:
        labels = torch.ones(batch_size) # real labels = 1
        
    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size) # fake labels = 0
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss
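A quick demonstration of these helpers on made-up logits (hypothetical values, purely to show the call pattern):

# dummy raw discriminator outputs (logits, pre-sigmoid) for a batch of three
dummy_logits = torch.tensor([2.0, -1.0, 0.5])

print(real_loss(dummy_logits))               # BCE against labels of 1
print(real_loss(dummy_logits, smooth=True))  # BCE against smoothed labels of 0.9
print(fake_loss(dummy_logits))               # BCE against labels of 0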

import torch.optim as optim

# Optimizers
lr = 0.002

# Create optimizers for the discriminator and generator
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)

import pickle as pkl

# training hyperparams
num_epochs = 100

# keep track of loss and generated, "fake" samples
samples = []
losses = []

print_every = 400

# Get some fixed latent vectors for sampling. These are held constant
# throughout training and let us inspect the generator's progress on the same inputs
sample_size=16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

# train the network
D.train()
G.train()
for epoch in range(num_epochs):
    
    for batch_i, (real_images, _) in enumerate(train_loader):
                
        batch_size = real_images.size(0)
        
        ## Important rescaling step ## 
        real_images = real_images*2 - 1  # rescale input images from [0,1) to [-1, 1)
        
        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        
        d_optimizer.zero_grad()
        
        # 1. Train with real images

        # Compute the discriminator losses on real images 
        # smooth the real labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)
        
        # 2. Train with fake images
        
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)
        
        # Compute the discriminator losses on fake images        
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)
        
        # add up loss and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()
        
        
        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()
        
        # 1. Train with fake images and flipped labels
        
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)
        
        # Compute the discriminator losses on fake images 
        # using flipped labels!
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake) # use real loss to flip labels
        
        # perform backprop
        g_loss.backward()
        g_optimizer.step()

        # Print some loss stats
        if batch_i % print_every == 0:
            # print discriminator and generator loss
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    
    ## AFTER EACH EPOCH##
    # append discriminator loss and generator loss
    losses.append((d_loss.item(), g_loss.item()))
    
    # generate and save sample, fake images
    G.eval() # eval mode for generating samples
    samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train() # back to train mode


# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
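To actually look at the saved samples, here is a minimal viewing sketch (assuming the samples list built above; the 28×28 reshape matches g_output_size = 784):

# display the 16 fixed-z samples from the final epoch
final = samples[-1].detach()

fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for img, ax in zip(final, axes.flatten()):
    ax.imshow(img.view(28, 28).numpy(), cmap='gray')  # tanh outputs in [-1, 1]
    ax.axis('off')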

The training loss log:

Epoch [    1/  100] | d_loss: 1.4183 | g_loss: 0.6135
Epoch [    1/  100] | d_loss: 0.9797 | g_loss: 3.5072
Epoch [    1/  100] | d_loss: 0.6198 | g_loss: 5.3375
Epoch [    2/  100] | d_loss: 9.4996 | g_loss: 0.4747
Epoch [    2/  100] | d_loss: 1.0309 | g_loss: 2.0515
Epoch [    2/  100] | d_loss: 0.9327 | g_loss: 1.7079
Epoch [    3/  100] | d_loss: 2.1240 | g_loss: 1.3280
Epoch [    3/  100] | d_loss: 0.9323 | g_loss: 1.2268
Epoch [    3/  100] | d_loss: 1.1284 | g_loss: 1.5946
Epoch [    4/  100] | d_loss: 2.1666 | g_loss: 3.2032
Epoch [    4/  100] | d_loss: 0.7609 | g_loss: 1.6455
Epoch [    4/  100] | d_loss: 0.9046 | g_loss: 3.0293
Epoch [    5/  100] | d_loss: 1.7139 | g_loss: 1.2261
Epoch [    5/  100] | d_loss: 0.7299 | g_loss: 10.3912
Epoch [    5/  100] | d_loss: 0.3566 | g_loss: 27.5909
Epoch [    6/  100] | d_loss: 6.6589 | g_loss: 6.5818
Epoch [    6/  100] | d_loss: 0.4291 | g_loss: 14.8082
Epoch [    6/  100] | d_loss: 0.3489 | g_loss: 17.6337
Epoch [    7/  100] | d_loss: 3.3547 | g_loss: 9.0117
Epoch [    7/  100] | d_loss: 0.3846 | g_loss: 19.2056
Epoch [    7/  100] | d_loss: 0.3466 | g_loss: 17.6365
Epoch [    8/  100] | d_loss: 0.3601 | g_loss: 17.7147
Epoch [    8/  100] | d_loss: 0.4702 | g_loss: 26.0832
Epoch [    8/  100] | d_loss: 0.3763 | g_loss: 20.9755
Epoch [    9/  100] | d_loss: 0.6864 | g_loss: 11.5328
Epoch [    9/  100] | d_loss: 0.4102 | g_loss: 15.1188
Epoch [    9/  100] | d_loss: 0.3541 | g_loss: 18.3997
Epoch [   10/  100] | d_loss: 0.6527 | g_loss: 13.2949
Epoch [   10/  100] | d_loss: 0.4098 | g_loss: 12.7677
Epoch [   10/  100] | d_loss: 0.4056 | g_loss: 14.1539
Epoch [   11/  100] | d_loss: 0.9110 | g_loss: 18.9935
Epoch [   11/  100] | d_loss: 0.3591 | g_loss: 11.9452
Epoch [   11/  100] | d_loss: 0.3530 | g_loss: 19.7258
Epoch [   12/  100] | d_loss: 2.2365 | g_loss: 12.2991
Epoch [   12/  100] | d_loss: 0.3457 | g_loss: 17.7986
Epoch [   12/  100] | d_loss: 0.3648 | g_loss: 19.3954
Epoch [   13/  100] | d_loss: 2.0952 | g_loss: 12.1697
Epoch [   13/  100] | d_loss: 0.3805 | g_loss: 9.8293
Epoch [   13/  100] | d_loss: 0.3445 | g_loss: 20.9563
Epoch [   14/  100] | d_loss: 0.8317 | g_loss: 17.5827
Epoch [   14/  100] | d_loss: 0.3667 | g_loss: 16.4045
Epoch [   14/  100] | d_loss: 0.3387 | g_loss: 18.5352
Epoch [   15/  100] | d_loss: 0.7394 | g_loss: 16.3790
Epoch [   15/  100] | d_loss: 0.3785 | g_loss: 12.4679
Epoch [   15/  100] | d_loss: 0.3351 | g_loss: 21.7534
Epoch [   16/  100] | d_loss: 1.0355 | g_loss: 15.6396
Epoch [   16/  100] | d_loss: 0.3729 | g_loss: 7.6380
Epoch [   16/  100] | d_loss: 0.3675 | g_loss: 15.9038
Epoch [   17/  100] | d_loss: 0.8700 | g_loss: 12.5654
Epoch [   17/  100] | d_loss: 0.4025 | g_loss: 9.8733
Epoch [   17/  100] | d_loss: 0.3840 | g_loss: 16.0042
Epoch [   18/  100] | d_loss: 0.7950 | g_loss: 13.7673
Epoch [   18/  100] | d_loss: 0.3454 | g_loss: 11.1624
Epoch [   18/  100] | d_loss: 0.3722 | g_loss: 14.9801
Epoch [   19/  100] | d_loss: 0.4957 | g_loss: 13.4140
Epoch [   19/  100] | d_loss: 0.3782 | g_loss: 10.0773
Epoch [   19/  100] | d_loss: 0.3422 | g_loss: 18.9227
Epoch [   20/  100] | d_loss: 0.3525 | g_loss: 15.9695
Epoch [   20/  100] | d_loss: 0.3397 | g_loss: 12.6761
Epoch [   20/  100] | d_loss: 0.3352 | g_loss: 18.0243
Epoch [   21/  100] | d_loss: 0.9658 | g_loss: 14.0450
Epoch [   21/  100] | d_loss: 0.3488 | g_loss: 12.1490
Epoch [   21/  100] | d_loss: 0.3339 | g_loss: 17.1810
Epoch [   22/  100] | d_loss: 0.6433 | g_loss: 14.0821
Epoch [   22/  100] | d_loss: 0.3382 | g_loss: 16.5094
Epoch [   22/  100] | d_loss: 0.3544 | g_loss: 19.1663
Epoch [   23/  100] | d_loss: 1.5242 | g_loss: 10.9259
Epoch [   23/  100] | d_loss: 0.3641 | g_loss: 14.2070
Epoch [   23/  100] | d_loss: 0.3361 | g_loss: 19.6924
Epoch [   24/  100] | d_loss: 0.8881 | g_loss: 16.3108
Epoch [   24/  100] | d_loss: 0.4295 | g_loss: 10.7050
Epoch [   24/  100] | d_loss: 0.3383 | g_loss: 15.5534
Epoch [   25/  100] | d_loss: 0.4760 | g_loss: 16.3258
Epoch [   25/  100] | d_loss: 0.3646 | g_loss: 11.3346
Epoch [   25/  100] | d_loss: 0.3497 | g_loss: 17.2252
Epoch [   26/  100] | d_loss: 0.4799 | g_loss: 18.8594
Epoch [   26/  100] | d_loss: 0.3793 | g_loss: 10.4789
Epoch [   26/  100] | d_loss: 0.3610 | g_loss: 18.0723
Epoch [   27/  100] | d_loss: 0.5806 | g_loss: 16.5386
Epoch [   27/  100] | d_loss: 0.4071 | g_loss: 11.2509
Epoch [   27/  100] | d_loss: 0.3344 | g_loss: 19.1435
Epoch [   28/  100] | d_loss: 0.7388 | g_loss: 15.6390
Epoch [   28/  100] | d_loss: 0.3862 | g_loss: 10.3167
Epoch [   28/  100] | d_loss: 0.3493 | g_loss: 21.1869
Epoch [   29/  100] | d_loss: 1.6755 | g_loss: 10.4002
Epoch [   29/  100] | d_loss: 0.3782 | g_loss: 9.7869
Epoch [   29/  100] | d_loss: 0.3460 | g_loss: 17.8775
Epoch [   30/  100] | d_loss: 0.8180 | g_loss: 11.2492
Epoch [   30/  100] | d_loss: 0.4636 | g_loss: 7.5997
Epoch [   30/  100] | d_loss: 0.3632 | g_loss: 15.7323
Epoch [   31/  100] | d_loss: 0.6391 | g_loss: 15.8171
Epoch [   31/  100] | d_loss: 0.3859 | g_loss: 12.1420
Epoch [   31/  100] | d_loss: 0.3341 | g_loss: 11.3105
Epoch [   32/  100] | d_loss: 0.6047 | g_loss: 13.8931
Epoch [   32/  100] | d_loss: 0.4110 | g_loss: 9.5100
Epoch [   32/  100] | d_loss: 0.3708 | g_loss: 15.7434
Epoch [   33/  100] | d_loss: 0.7056 | g_loss: 13.9771
Epoch [   33/  100] | d_loss: 0.4062 | g_loss: 11.1046
Epoch [   33/  100] | d_loss: 0.3318 | g_loss: 14.1029
Epoch [   34/  100] | d_loss: 1.2167 | g_loss: 11.7245
Epoch [   34/  100] | d_loss: 0.4461 | g_loss: 7.3524
Epoch [   34/  100] | d_loss: 0.3491 | g_loss: 13.6828
Epoch [   35/  100] | d_loss: 1.0649 | g_loss: 11.6849
Epoch [   35/  100] | d_loss: 0.4522 | g_loss: 8.2442
Epoch [   35/  100] | d_loss: 0.3453 | g_loss: 15.2782
Epoch [   36/  100] | d_loss: 1.7150 | g_loss: 10.2848
Epoch [   36/  100] | d_loss: 0.4152 | g_loss: 7.9770
Epoch [   36/  100] | d_loss: 0.3413 | g_loss: 15.9781
Epoch [   37/  100] | d_loss: 1.4314 | g_loss: 13.1972
Epoch [   37/  100] | d_loss: 0.4598 | g_loss: 9.6094
Epoch [   37/  100] | d_loss: 0.3980 | g_loss: 15.6485
Epoch [   38/  100] | d_loss: 1.4494 | g_loss: 11.0654
Epoch [   38/  100] | d_loss: 0.4490 | g_loss: 9.5278
Epoch [   38/  100] | d_loss: 0.3500 | g_loss: 14.6877
Epoch [   39/  100] | d_loss: 1.5234 | g_loss: 13.8237
Epoch [   39/  100] | d_loss: 0.3995 | g_loss: 11.1140
Epoch [   39/  100] | d_loss: 0.4001 | g_loss: 14.4916
Epoch [   40/  100] | d_loss: 2.0880 | g_loss: 9.9507
Epoch [   40/  100] | d_loss: 0.4193 | g_loss: 10.2406
Epoch [   40/  100] | d_loss: 0.3678 | g_loss: 14.5487
Epoch [   41/  100] | d_loss: 2.2361 | g_loss: 10.1447
Epoch [   41/  100] | d_loss: 0.3992 | g_loss: 11.9671
Epoch [   41/  100] | d_loss: 0.4041 | g_loss: 18.2163
Epoch [   42/  100] | d_loss: 1.5434 | g_loss: 10.6550
Epoch [   42/  100] | d_loss: 0.4292 | g_loss: 8.3420
Epoch [   42/  100] | d_loss: 0.3707 | g_loss: 17.8904
Epoch [   43/  100] | d_loss: 1.4780 | g_loss: 14.2423
Epoch [   43/  100] | d_loss: 0.4213 | g_loss: 15.0119
Epoch [   43/  100] | d_loss: 0.3488 | g_loss: 14.9278
Epoch [   44/  100] | d_loss: 1.3850 | g_loss: 14.2280
Epoch [   44/  100] | d_loss: 0.3986 | g_loss: 8.7429
Epoch [   44/  100] | d_loss: 0.4616 | g_loss: 17.7188
Epoch [   45/  100] | d_loss: 1.0322 | g_loss: 11.6225
Epoch [   45/  100] | d_loss: 0.4518 | g_loss: 8.2310
Epoch [   45/  100] | d_loss: 0.5395 | g_loss: 13.5624
Epoch [   46/  100] | d_loss: 0.8925 | g_loss: 11.7434
Epoch [   46/  100] | d_loss: 0.4946 | g_loss: 9.6476
Epoch [   46/  100] | d_loss: 0.3597 | g_loss: 14.2146
Epoch [   47/  100] | d_loss: 1.6359 | g_loss: 9.1033
Epoch [   47/  100] | d_loss: 0.4684 | g_loss: 10.1571
Epoch [   47/  100] | d_loss: 0.3915 | g_loss: 12.1220
Epoch [   48/  100] | d_loss: 1.6015 | g_loss: 10.8328
Epoch [   48/  100] | d_loss: 0.4896 | g_loss: 12.4736
Epoch [   48/  100] | d_loss: 0.3711 | g_loss: 16.8187
Epoch [   49/  100] | d_loss: 1.7252 | g_loss: 13.1232
Epoch [   49/  100] | d_loss: 0.3986 | g_loss: 12.4418
Epoch [   49/  100] | d_loss: 0.3884 | g_loss: 15.6806
Epoch [   50/  100] | d_loss: 0.6736 | g_loss: 13.3662
Epoch [   50/  100] | d_loss: 0.4813 | g_loss: 8.3349
Epoch [   50/  100] | d_loss: 0.3528 | g_loss: 15.8811
Epoch [   51/  100] | d_loss: 0.5697 | g_loss: 10.5772
Epoch [   51/  100] | d_loss: 0.3887 | g_loss: 11.2442
Epoch [   51/  100] | d_loss: 0.3533 | g_loss: 21.8093
Epoch [   52/  100] | d_loss: 1.1224 | g_loss: 13.7392
Epoch [   52/  100] | d_loss: 0.4418 | g_loss: 9.4389
Epoch [   52/  100] | d_loss: 0.3438 | g_loss: 13.8667
Epoch [   53/  100] | d_loss: 0.9507 | g_loss: 13.5954
Epoch [   53/  100] | d_loss: 0.3734 | g_loss: 10.2547
Epoch [   53/  100] | d_loss: 0.3774 | g_loss: 16.8306
Epoch [   54/  100] | d_loss: 1.2391 | g_loss: 14.0591
Epoch [   54/  100] | d_loss: 0.4399 | g_loss: 15.4043
Epoch [   54/  100] | d_loss: 0.3683 | g_loss: 13.9010
Epoch [   55/  100] | d_loss: 1.0226 | g_loss: 13.8092
Epoch [   55/  100] | d_loss: 0.4058 | g_loss: 12.3307
Epoch [   55/  100] | d_loss: 0.3373 | g_loss: 14.5292
Epoch [   56/  100] | d_loss: 0.7458 | g_loss: 12.7015
Epoch [   56/  100] | d_loss: 0.6104 | g_loss: 11.0592
Epoch [   56/  100] | d_loss: 0.3599 | g_loss: 14.5782
Epoch [   57/  100] | d_loss: 0.7517 | g_loss: 14.1672
Epoch [   57/  100] | d_loss: 0.4376 | g_loss: 11.6036
Epoch [   57/  100] | d_loss: 0.3649 | g_loss: 11.8032
Epoch [   58/  100] | d_loss: 0.8386 | g_loss: 15.6734
Epoch [   58/  100] | d_loss: 0.6015 | g_loss: 10.0157
Epoch [   58/  100] | d_loss: 0.3409 | g_loss: 12.7498
Epoch [   59/  100] | d_loss: 0.7816 | g_loss: 12.8772
Epoch [   59/  100] | d_loss: 0.3815 | g_loss: 13.5657
Epoch [   59/  100] | d_loss: 0.3553 | g_loss: 16.9895
Epoch [   60/  100] | d_loss: 1.8105 | g_loss: 13.1576
Epoch [   60/  100] | d_loss: 0.5160 | g_loss: 14.5819
Epoch [   60/  100] | d_loss: 0.4257 | g_loss: 19.7234
Epoch [   61/  100] | d_loss: 0.9425 | g_loss: 11.9711
Epoch [   61/  100] | d_loss: 0.4487 | g_loss: 10.1208
Epoch [   61/  100] | d_loss: 0.3690 | g_loss: 12.6821
Epoch [   62/  100] | d_loss: 1.1703 | g_loss: 14.9203
Epoch [   62/  100] | d_loss: 0.4267 | g_loss: 12.9844
Epoch [   62/  100] | d_loss: 0.3522 | g_loss: 13.4276
Epoch [   63/  100] | d_loss: 1.1755 | g_loss: 14.1319
Epoch [   63/  100] | d_loss: 0.4458 | g_loss: 13.5662
Epoch [   63/  100] | d_loss: 0.3534 | g_loss: 14.9882
Epoch [   64/  100] | d_loss: 1.0760 | g_loss: 14.4738
Epoch [   64/  100] | d_loss: 0.4565 | g_loss: 12.1190
Epoch [   64/  100] | d_loss: 0.3553 | g_loss: 15.1464
Epoch [   65/  100] | d_loss: 0.8142 | g_loss: 11.0828
Epoch [   65/  100] | d_loss: 0.4330 | g_loss: 14.1257
Epoch [   65/  100] | d_loss: 0.3599 | g_loss: 15.4030
Epoch [   66/  100] | d_loss: 0.6770 | g_loss: 13.1254
Epoch [   66/  100] | d_loss: 0.5536 | g_loss: 13.1232
Epoch [   66/  100] | d_loss: 0.3725 | g_loss: 12.5105
Epoch [   67/  100] | d_loss: 0.6099 | g_loss: 12.6813
Epoch [   67/  100] | d_loss: 0.6119 | g_loss: 11.8948
Epoch [   67/  100] | d_loss: 0.3643 | g_loss: 13.8359
Epoch [   68/  100] | d_loss: 0.8602 | g_loss: 13.3447
Epoch [   68/  100] | d_loss: 0.6922 | g_loss: 16.1481
Epoch [   68/  100] | d_loss: 0.3522 | g_loss: 15.6391
Epoch [   69/  100] | d_loss: 1.3279 | g_loss: 13.6985
Epoch [   69/  100] | d_loss: 0.4594 | g_loss: 13.2752
Epoch [   69/  100] | d_loss: 0.3806 | g_loss: 11.9751
Epoch [   70/  100] | d_loss: 0.7558 | g_loss: 13.9299
Epoch [   70/  100] | d_loss: 0.4122 | g_loss: 14.1247
Epoch [   70/  100] | d_loss: 0.3731 | g_loss: 14.1739
Epoch [   71/  100] | d_loss: 0.6188 | g_loss: 15.3189
Epoch [   71/  100] | d_loss: 0.4105 | g_loss: 16.1061
Epoch [   71/  100] | d_loss: 0.3339 | g_loss: 16.0961
Epoch [   72/  100] | d_loss: 0.7347 | g_loss: 14.8844
Epoch [   72/  100] | d_loss: 0.6053 | g_loss: 11.8251
Epoch [   72/  100] | d_loss: 0.4238 | g_loss: 14.9223
Epoch [   73/  100] | d_loss: 1.1834 | g_loss: 15.3078
Epoch [   73/  100] | d_loss: 0.5181 | g_loss: 12.1415
Epoch [   73/  100] | d_loss: 0.3445 | g_loss: 14.8677
Epoch [   74/  100] | d_loss: 0.8008 | g_loss: 13.1710
Epoch [   74/  100] | d_loss: 0.4513 | g_loss: 15.9446
Epoch [   74/  100] | d_loss: 0.4414 | g_loss: 20.3710
Epoch [   75/  100] | d_loss: 1.2803 | g_loss: 15.4193
Epoch [   75/  100] | d_loss: 0.5942 | g_loss: 10.7087
Epoch [   75/  100] | d_loss: 0.3437 | g_loss: 17.9181
Epoch [   76/  100] | d_loss: 1.0286 | g_loss: 12.8885
Epoch [   76/  100] | d_loss: 0.4920 | g_loss: 18.6149
Epoch [   76/  100] | d_loss: 0.3751 | g_loss: 17.4362
Epoch [   77/  100] | d_loss: 0.9973 | g_loss: 17.9688
Epoch [   77/  100] | d_loss: 0.4018 | g_loss: 15.8639
Epoch [   77/  100] | d_loss: 0.3448 | g_loss: 16.8935
Epoch [   78/  100] | d_loss: 1.0238 | g_loss: 14.3503
Epoch [   78/  100] | d_loss: 0.4841 | g_loss: 15.9601
Epoch [   78/  100] | d_loss: 0.3357 | g_loss: 15.6731
Epoch [   79/  100] | d_loss: 0.9124 | g_loss: 14.2594
Epoch [   79/  100] | d_loss: 0.3902 | g_loss: 14.8787
Epoch [   79/  100] | d_loss: 0.3434 | g_loss: 17.8774
Epoch [   80/  100] | d_loss: 0.7669 | g_loss: 14.9693
Epoch [   80/  100] | d_loss: 0.3639 | g_loss: 20.3815
Epoch [   80/  100] | d_loss: 0.3442 | g_loss: 13.0722
Epoch [   81/  100] | d_loss: 1.0707 | g_loss: 14.3246
Epoch [   81/  100] | d_loss: 0.4075 | g_loss: 17.0175
Epoch [   81/  100] | d_loss: 0.3697 | g_loss: 18.8611
Epoch [   82/  100] | d_loss: 1.0980 | g_loss: 14.5868
Epoch [   82/  100] | d_loss: 0.4924 | g_loss: 19.5658
Epoch [   82/  100] | d_loss: 0.3424 | g_loss: 12.9414
Epoch [   83/  100] | d_loss: 0.9112 | g_loss: 17.5380
Epoch [   83/  100] | d_loss: 0.3837 | g_loss: 20.3259
Epoch [   83/  100] | d_loss: 0.3333 | g_loss: 18.1492
Epoch [   84/  100] | d_loss: 0.6574 | g_loss: 16.3081
Epoch [   84/  100] | d_loss: 0.4217 | g_loss: 17.3476
Epoch [   84/  100] | d_loss: 0.3554 | g_loss: 18.8399
Epoch [   85/  100] | d_loss: 0.8820 | g_loss: 17.5911
Epoch [   85/  100] | d_loss: 0.4262 | g_loss: 16.1993
Epoch [   85/  100] | d_loss: 0.3485 | g_loss: 14.7559
Epoch [   86/  100] | d_loss: 0.4581 | g_loss: 16.3607
Epoch [   86/  100] | d_loss: 0.4308 | g_loss: 16.9571
Epoch [   86/  100] | d_loss: 0.3995 | g_loss: 15.4270
Epoch [   87/  100] | d_loss: 0.6950 | g_loss: 14.6800
Epoch [   87/  100] | d_loss: 0.4876 | g_loss: 19.5503
Epoch [   87/  100] | d_loss: 0.3796 | g_loss: 14.6323
Epoch [   88/  100] | d_loss: 0.8164 | g_loss: 16.1642
Epoch [   88/  100] | d_loss: 0.4529 | g_loss: 17.5407
Epoch [   88/  100] | d_loss: 0.3481 | g_loss: 15.0995
Epoch [   89/  100] | d_loss: 0.4253 | g_loss: 14.2475
Epoch [   89/  100] | d_loss: 0.4663 | g_loss: 20.7928
Epoch [   89/  100] | d_loss: 0.3493 | g_loss: 12.6848
Epoch [   90/  100] | d_loss: 0.6324 | g_loss: 14.3659
Epoch [   90/  100] | d_loss: 0.3564 | g_loss: 14.4839
Epoch [   90/  100] | d_loss: 0.3456 | g_loss: 13.2325
Epoch [   91/  100] | d_loss: 0.4724 | g_loss: 12.8062
Epoch [   91/  100] | d_loss: 0.4218 | g_loss: 15.9515
Epoch [   91/  100] | d_loss: 0.3742 | g_loss: 14.1208
Epoch [   92/  100] | d_loss: 0.5828 | g_loss: 14.9528
Epoch [   92/  100] | d_loss: 0.4663 | g_loss: 17.1954
Epoch [   92/  100] | d_loss: 0.3966 | g_loss: 15.6063
Epoch [   93/  100] | d_loss: 0.3877 | g_loss: 18.0134
Epoch [   93/  100] | d_loss: 0.4227 | g_loss: 16.2912
Epoch [   93/  100] | d_loss: 0.3556 | g_loss: 14.5541
Epoch [   94/  100] | d_loss: 0.3899 | g_loss: 17.4970
Epoch [   94/  100] | d_loss: 0.5194 | g_loss: 17.1904
Epoch [   94/  100] | d_loss: 0.3436 | g_loss: 12.9894
Epoch [   95/  100] | d_loss: 0.6292 | g_loss: 17.2345
Epoch [   95/  100] | d_loss: 0.4242 | g_loss: 16.9986
Epoch [   95/  100] | d_loss: 0.3520 | g_loss: 16.3274
Epoch [   96/  100] | d_loss: 0.4138 | g_loss: 12.0211
Epoch [   96/  100] | d_loss: 0.3920 | g_loss: 16.1397
Epoch [   96/  100] | d_loss: 0.3473 | g_loss: 17.8238
Epoch [   97/  100] | d_loss: 0.7130 | g_loss: 16.0466
Epoch [   97/  100] | d_loss: 0.4132 | g_loss: 17.5734
Epoch [   97/  100] | d_loss: 0.3350 | g_loss: 15.7057
Epoch [   98/  100] | d_loss: 0.6129 | g_loss: 18.6339
Epoch [   98/  100] | d_loss: 0.3829 | g_loss: 17.1675
Epoch [   98/  100] | d_loss: 0.3422 | g_loss: 17.0065
Epoch [   99/  100] | d_loss: 0.6298 | g_loss: 16.4161
Epoch [   99/  100] | d_loss: 0.3821 | g_loss: 16.3967
Epoch [   99/  100] | d_loss: 0.3362 | g_loss: 14.8846
Epoch [  100/  100] | d_loss: 0.4717 | g_loss: 16.5175
Epoch [  100/  100] | d_loss: 0.4092 | g_loss: 18.5726
Epoch [  100/  100] | d_loss: 0.3331 | g_loss: 15.4866
# plot the training losses
fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()

The generated images are still rather poor.
