对抗学习DCGAN网络

DCGAN教程

1. 简介

本教程通过一个例子来对 DCGANs 进行介绍。我们将会训练一个生成对抗网络(GAN)用于在展示了许多真正的名人的图片后产生新的名人。 这里的大部分代码来自pytorch/examples中的dcgan 实现,本文档将对实现进行进行全面 的介绍,并阐明该模型的工作原理以及为什么如此。但是不需要担心,你并不需要事先了解 GAN,但可能需要花一些事件来推理一下底层实际发生的事情。此外,为了有助于节省时间,最好是使用一个GPU,或者两个。让我们从头开始。

2. 生成对抗网络(Generative Adversarial Networks)

2.1 什么是 GAN

GANs是用于 DL (Deep Learning)模型去捕获训练数据分布情况的框架,以此我们可以从同一分布中生成新的数据。GANs是有Ian Goodfellow 于2014年提出,并且首次在论文GenerativeAdversarial Nets中描述。它们由两个不同的模型组成,一个是生成器一个是判别器。生成器的工作是产生看起来像训练图像的“假”图像;判别器的工作是 查看图像并输出它是否是真实的训练图像或来自生成器的伪图像。在训练期间,产生器不断尝试通过产生越来越好的假动作来超越判别器, 而判别器则是为了更好地检测并准确地对真实和假图像进行分类。这个游戏的平衡是当生成器产生完美的假动作以使假图像看起来像是来自 训练数据,而判别器总是猜测生成器输出图像为真或假的概率为50%。

现在,我们开始定义在这个教程中使用到的一些符号。

  • 判别器的符号定义

x表示代表一张图像的数据,D(x) 是判别器网络,它输出x 来自训练数据而不是生成器的(标量)概率。这里,由于我们处理图像, D(x) 的输入是 CHW 大小为3x64x64的图像。直观地,当来自训练数据时D(x) 应该是 HIGH ,而当x 来自生成器时应该是 LOW。D(x) 也可以被认为是传统的二元分类器。

  • 生成器的符号定义

对于生成器的符号,让z 是从标准正态分布中采样的潜在空间矢量,G(z) 表示将潜在向量 映射到数据空间的生成器函数,G 的目标是估计训练数据来自什么分布(Pdata),以便它可以 根据估计的分布(Pg)生成假样本。

因此,D(G(z)) 是生成器的输出是真实图像的概率(标量)。正如Goodfellow 的论文中所描述的,DG 玩一个极小极大的游戏,其中D 试图最大化它正确地分类真实数据和假样本的概率,并且G 试图最小化D 预测其输出是假的概率log(1-D(G(x)))。从论文来看,GAN 损失函数是:
在这里插入图片描述
理论上,这个极小极大游戏的解决方案是Pg = Pdata,如果输入的是真实的或假的,则判别器会随机猜测。然而,GAN 的收敛理论仍在积极研究中,实际上模型并不总是训练到这一点。

2.2 什么是 DCGAN

DCGAN 是上述 GAN 的直接扩展,区别的是它分别在判别器和生成器中明确地使用了卷积和卷积转置层。它首先是由Radford等人在论文Unsupervised Representation Learning With DeepConvolutional Generative Adversarial Networks中提出。判别器由 strided convolution layers、batch norm layers 和 LeakyReLU activations 组成,它输入 3x64x64 的图像, 然后输出的是一个代表输入是来自实际数据分布的标量概率。生成器则是由 convolutional-transpose layers、 batchnorm layers 和 ReLU activations 组成。它的输入是从标准正态分布中绘制的潜在向量,输出是3x64x64 的 RGB 图像。strided conv-transpose layers 允许潜在标量转换成具有与图像相同形状的体积。在本文中,作者还提供了一些有关如何设置优化器,如何计算损失函数以及如何初始化 模型权重的提示,所有这些都将在后面的章节中进行说明。

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # 如果你想要新的结果就是要这段代码
print("Random Seed: ", manualSeed)
# 以下两个都是方便复现结果而设置随机数
random.seed(manualSeed)
torch.manual_seed(manualSeed)

输出结果:

Random Seed:  999

3. DCGAN实现过程

3.1 输入

让我们定义输入数据去运行我们的教程:

  • dataroot:存放数据集根目录的路径。我们将在下一节中详细讨论数据集
  • workers:使用DataLoader加载数据的工作线程数
  • batch_size:训练中使用的batch大小。在DCGAN论文中batch的大小为128
  • image_size:用于训练的图像的空间大小。此实现默认 64x64。如果需要其他尺寸,则必须改变和 的结构。有关详细信息,请参见此处
  • nc:输入图像中的颜色通道数。对于彩色图像,这是参数设置为3
  • nz:潜在向量的长度
  • ngf:与通过生成器携带的特征图的深度有关
  • ndf:设置通过判别器传播的特征映射的深度
  • num_epochs:要运行的训练的epoch数量。长时间的训练可能会带来更好的结果,但也需要更长的时间
  • lr:学习速率。如DCGAN论文中所述,此数字应为0.0002
  • beta1:适用于Adam优化器的beta1超参数。如论文所述,此数字应为0.5
  • ngpu:可用的GPU数量。如果为0,则代码将以CPU模式运行。如果此数字大于0,它将在该数量的GPU上运行
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

3.2 数据

在本教程中,我们将使用Celeb-A Faces数据集,该数据集可以在链接或Google Drive中下载。数据集将下载为名为img_align_celeba.zip的文件。下载后,创建名为celeba的目录并将zip文件解压缩到该目录中。然后,将此笔记中的数据对象输入设置为刚刚创建的celeba目录。生成的目录结构应该是:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

这是一个重要的步骤,因为我们将使用ImageFolder数据集类,它要求在数据集的根文件夹中有子目录。现在,我们可以创建数据集,创 建数据加载器,设置要运行的设备,以及最后可视化一些训练数据。

# 我们可以按照设置的方式使用图像文件夹数据集。
# 创建数据集
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# # 创建加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

在这里插入图片描述

3.3 实现

通过设置输入参数和准备好的数据集,我们现在可以进入真正的实现步骤。我们将从权重初始化策略开始,然后详细讨论生成器,鉴别器, 损失函数和训练循环。

3.3.1 权重初始化

在DCGAN论文中,作者指出所有模型权重应从正态分布中随机初始化,mean = 0,stdev = 0.02。weights_init函数将初始化模型作为 输入,并重新初始化所有卷积,卷积转置和batch标准化层以满足此标准。初始化后立即将此函数应用于模型。

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
3.3.2 生成器

生成器用于将潜在空间矢量(z)映射到数据空间。由于我们的数据是图像,因此将 转换为数据空间意味着最终创建与训练图像具有相同大小的RGB图像(即3x64x64)。实际上,这是通过一系列跨步的二维卷积转置层实现的, 每个转换层与二维批量标准层和relu activation进行配对。生成器的输出通过tanh函数输入,使其返回到[-1,1]范围的输入数据。值得 注意的是在转换层之后存在批量范数函数,因为这是DCGAN论文的关键贡献。这些层有助于训练期间的梯度流动。DCGAN论文中的生成器中 的图像如下所示:
在这里插入图片描述

请注意,我们对输入怎么设置(nz,ngfnc)会影响代码中的生成器体系结构。nz 是输入向量的长度, ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数(对于RGB图像,设置为3)。下面是生成器的代码。

  • 生成器代码
# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

现在,我们可以实例化生成器并应用weights_init函数。查看打印的模型以查看生成器对象的结构。

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)
  • 输出结果:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
3.3.3 判别器

如上所述,判别器是二进制分类网络,它将图像作为输入并输出输入图像是真实的标量概率(与假的相反)。这里, 采用 3x64x64 的输入图像,通过一系列Conv2d,BatchNorm2d和LeakyReLU层处理它,并通过Sigmoid激活函数输出 最终概率。如果问题需要,可以使用更多层扩展此体系结构,但使用strided convolution(跨步卷积),BatchNorm和LeakyReLU具有重要 意义。DCGAN论文提到使用跨步卷积而不是池化到降低采样是一种很好的做法,因为它可以让网络学习自己的池化功能。批量标准和 leaky relu函数也促进良好的梯度流,这对于和的学习过程都是至关重要的。

  • 判别器代码
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

现在,与生成器一样,我们可以创建判别器,应用weights_init函数,并打印模型的结构。

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)
  • 输出结果
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
3.3.4 损失函数和优化器

通过D和G设置,我们可以指定他们如何通过损失函数和优化器学习。我们将使用PyTorch中定义的 二进制交叉熵损失()BCELoss函数:
在这里插入图片描述
注意该函数如何计算目标函数中的两个对数分量(即logD(x)和log(1-D(G(z)))。我们可以指定用于输入y 的BCE方程的哪个部分。这是在即将出现的训练循环中完成的,但重要的是要了解我们如何通过改变y(即GT标签)来选择我们希望计算的组件。

接下来,我们将真实标签定义为1,将假标签定义为0。这些标签将在计算和的损失时使用,这 也是原始 GAN 论文中使用的惯例。最后,我们设置了两个单独的优化器,一个用于D,一个用于G。 如 DCGAN 论文中所述,两者都是Adam优化器,学习率为0.0002,Beta1 = 0.5。 为了跟踪生成器的学习进度,我们将生成一组固定的潜在 向量,这些向量是从高斯分布(即fixed_noise)中提取的。在训练循环中,我们将周期性地将此fixed_noise输入到 中,并且在迭代中我们将看到图像形成于噪声之外。

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
3.3.4 训练

最后,既然已经定义了 GAN 框架的所有部分,我们就可以对其进行训练了。请注意,训练GAN在某种程度上是一种艺术形式,因为不正确 的超参数设置会导致对错误的解释很少的模式崩溃,在这里,我们将密切关注Goodfellow的论文中的算法1,同时遵守ganhacks 中展示的一些最佳实践。也就是说,我们将“为真实和虚假”图像构建不同的 mini-batches ,并且还调整的目标函 数以最大化。训练分为两个主要部分,第1部分更新判别器,第2部分更新生成器。

Part 1 - Train the Discriminator训练判别器

回想一下,训练判别器的目的是最大化将给定输入正确分类为真实或假的概率。就Goodfellow而言,我们希望“通过提升其随机梯度来更新判别器”。实际上,我们希望最大化log(D(x))+log(1-D(G(z)))。由于ganhacks的独立 mini-batch 建议,我们将分两步计算。首先,我们将从训练集构建一批实际样本,向前通过D,计算损失log(D(x)),然后计算向后传递的梯度。其次,我们将用当前生成器构造一批假样本,通过D向前传递该 batch,计算损失log(1-D(G(z))), 并通过反向传递累积梯度。现在,随着从全实时和全实时批量累积的梯度,我们称之为Discriminator优化器的一步。

Part 2 - Train the Generator训练生成器

正如原始论文所述,我们希望通过最小化log(1-D(G(z)))来训练生成器,以便产生更好的伪样本。如上所述,Goodfellow 表明这不会提供足够的梯度,尤其是在学习过程的早期阶段。作为修复,我们希望最大化log(D(G(z)))。在代码中,我们通过 以下方式实现此目的:使用判别器对第1部分的生成器中的输出进行分类,使用真实标签: GT 标签计算G的损失, 在向后传递中计算的梯度,最后使用优化器步骤更新G的参数。使用真实标签作为损失函数的GT 标签似乎是违反直觉的,但是这允许我们使用 BCELoss的log(x)部分(而不是log(1-x)部分), 这正是我们想要。

最后,我们将进行一些统计报告,在每个epoch结束时,我们将通过生成器推送我们的fixed_noisebatch,以直观地跟踪G训练的进度。训练的统计数据是:

  • Loss_D: 判别器损失计算为所有实际批次和所有假批次的损失总和log(D(x)) + log(D(G(z)))
  • Loss_G: 计算生成器损失log(D(G(z)))
  • D(x): 所有实际批次的判别器的平均输出(整批)。当G变好时这应该从接近1开始,然后理论上收敛到0.5。
  • D(G(z)): 所有假批次的平均判别器输出。第一个数字是在D更新之前,第二个数字是在D更新之后。当G变好时,这些数字应该从0开始并收敛到0.5。

此步骤可能需要一段时间,具体取决于您运行的epoch数以及是否从数据集中删除了一些数据。

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

输出结果:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.6280  Loss_G: 5.5230  D(x): 0.5728    D(G(z)): 0.5501 / 0.0065
[0/5][50/1583]  Loss_D: 0.2304  Loss_G: 14.5537 D(x): 0.8902    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.4667  Loss_G: 7.5298  D(x): 0.8012    D(G(z)): 0.0140 / 0.0019
[0/5][150/1583] Loss_D: 1.6321  Loss_G: 9.5917  D(x): 0.9843    D(G(z)): 0.7314 / 0.0004
[0/5][200/1583] Loss_D: 0.7978  Loss_G: 9.0436  D(x): 0.9225    D(G(z)): 0.4432 / 0.0007
[0/5][250/1583] Loss_D: 0.4024  Loss_G: 3.9679  D(x): 0.8135    D(G(z)): 0.0361 / 0.0447
[0/5][300/1583] Loss_D: 0.5149  Loss_G: 4.1895  D(x): 0.7881    D(G(z)): 0.1290 / 0.0292
[0/5][350/1583] Loss_D: 0.6168  Loss_G: 2.8078  D(x): 0.6550    D(G(z)): 0.0332 / 0.0885
[0/5][400/1583] Loss_D: 0.5454  Loss_G: 5.6257  D(x): 0.8727    D(G(z)): 0.2721 / 0.0091
[0/5][450/1583] Loss_D: 0.6010  Loss_G: 5.7853  D(x): 0.8588    D(G(z)): 0.2870 / 0.0087
[0/5][500/1583] Loss_D: 0.8000  Loss_G: 6.2724  D(x): 0.8474    D(G(z)): 0.4012 / 0.0040
[0/5][550/1583] Loss_D: 0.5219  Loss_G: 5.4993  D(x): 0.9168    D(G(z)): 0.3143 / 0.0073
[0/5][600/1583] Loss_D: 1.6073  Loss_G: 11.3235 D(x): 0.9741    D(G(z)): 0.6769 / 0.0001
[0/5][650/1583] Loss_D: 0.6905  Loss_G: 4.1023  D(x): 0.7250    D(G(z)): 0.2099 / 0.0327
[0/5][700/1583] Loss_D: 0.6160  Loss_G: 4.4970  D(x): 0.7873    D(G(z)): 0.1839 / 0.0265
[0/5][750/1583] Loss_D: 0.5432  Loss_G: 5.1344  D(x): 0.8419    D(G(z)): 0.2339 / 0.0105
[0/5][800/1583] Loss_D: 0.3196  Loss_G: 4.4716  D(x): 0.8096    D(G(z)): 0.0459 / 0.0201
[0/5][850/1583] Loss_D: 0.5087  Loss_G: 3.9806  D(x): 0.8488    D(G(z)): 0.2323 / 0.0337
[0/5][900/1583] Loss_D: 0.6441  Loss_G: 7.9880  D(x): 0.8974    D(G(z)): 0.3452 / 0.0012
[0/5][950/1583] Loss_D: 2.3273  Loss_G: 2.3773  D(x): 0.2035    D(G(z)): 0.0051 / 0.2024
[0/5][1000/1583]        Loss_D: 0.5630  Loss_G: 5.5355  D(x): 0.8624    D(G(z)): 0.2793 / 0.0089
[0/5][1050/1583]        Loss_D: 0.3751  Loss_G: 3.0624  D(x): 0.7592    D(G(z)): 0.0251 / 0.0705
[0/5][1100/1583]        Loss_D: 0.5808  Loss_G: 3.5756  D(x): 0.6847    D(G(z)): 0.0547 / 0.0510
[0/5][1150/1583]        Loss_D: 0.2253  Loss_G: 5.0725  D(x): 0.8796    D(G(z)): 0.0617 / 0.0132
[0/5][1200/1583]        Loss_D: 0.3175  Loss_G: 3.5829  D(x): 0.8602    D(G(z)): 0.1207 / 0.0436
[0/5][1250/1583]        Loss_D: 0.6691  Loss_G: 4.1399  D(x): 0.8065    D(G(z)): 0.2861 / 0.0291
[0/5][1300/1583]        Loss_D: 0.7186  Loss_G: 7.0522  D(x): 0.9462    D(G(z)): 0.4039 / 0.0021
[0/5][1350/1583]        Loss_D: 1.2718  Loss_G: 1.5942  D(x): 0.4068    D(G(z)): 0.0257 / 0.2724
[0/5][1400/1583]        Loss_D: 0.4084  Loss_G: 4.6125  D(x): 0.8783    D(G(z)): 0.2062 / 0.0171
[0/5][1450/1583]        Loss_D: 0.7464  Loss_G: 2.6160  D(x): 0.6230    D(G(z)): 0.1065 / 0.1111
[0/5][1500/1583]        Loss_D: 0.3353  Loss_G: 5.1631  D(x): 0.9297    D(G(z)): 0.1955 / 0.0109
[0/5][1550/1583]        Loss_D: 0.6005  Loss_G: 5.6019  D(x): 0.9059    D(G(z)): 0.3550 / 0.0062
[1/5][0/1583]   Loss_D: 0.4403  Loss_G: 3.4381  D(x): 0.8170    D(G(z)): 0.1487 / 0.0521
[1/5][50/1583]  Loss_D: 0.7481  Loss_G: 3.8258  D(x): 0.7920    D(G(z)): 0.3092 / 0.0402
[1/5][100/1583] Loss_D: 0.8191  Loss_G: 5.7194  D(x): 0.9222    D(G(z)): 0.4570 / 0.0071
[1/5][150/1583] Loss_D: 0.7062  Loss_G: 5.9147  D(x): 0.9072    D(G(z)): 0.4055 / 0.0052
[1/5][200/1583] Loss_D: 0.4160  Loss_G: 2.8901  D(x): 0.7845    D(G(z)): 0.1103 / 0.0784
[1/5][250/1583] Loss_D: 1.4022  Loss_G: 2.1375  D(x): 0.3602    D(G(z)): 0.0050 / 0.2027
[1/5][300/1583] Loss_D: 0.6249  Loss_G: 4.3820  D(x): 0.9234    D(G(z)): 0.3688 / 0.0230
[1/5][350/1583] Loss_D: 0.6409  Loss_G: 1.8723  D(x): 0.6454    D(G(z)): 0.0936 / 0.2279
[1/5][400/1583] Loss_D: 0.4709  Loss_G: 4.3662  D(x): 0.8718    D(G(z)): 0.2391 / 0.0233
[1/5][450/1583] Loss_D: 0.7901  Loss_G: 5.0924  D(x): 0.9571    D(G(z)): 0.4457 / 0.0157
[1/5][500/1583] Loss_D: 0.4266  Loss_G: 3.8325  D(x): 0.8907    D(G(z)): 0.2367 / 0.0340
[1/5][550/1583] Loss_D: 0.8709  Loss_G: 3.5440  D(x): 0.9120    D(G(z)): 0.4439 / 0.0568
[1/5][600/1583] Loss_D: 0.6492  Loss_G: 5.0122  D(x): 0.9129    D(G(z)): 0.3796 / 0.0126
[1/5][650/1583] Loss_D: 0.3493  Loss_G: 3.6410  D(x): 0.8787    D(G(z)): 0.1753 / 0.0361
[1/5][700/1583] Loss_D: 0.4558  Loss_G: 3.1112  D(x): 0.8550    D(G(z)): 0.2272 / 0.0613
[1/5][750/1583] Loss_D: 1.2227  Loss_G: 0.3276  D(x): 0.3891    D(G(z)): 0.0125 / 0.7497
[1/5][800/1583] Loss_D: 0.5045  Loss_G: 3.1462  D(x): 0.7709    D(G(z)): 0.1693 / 0.0641
[1/5][850/1583] Loss_D: 0.7897  Loss_G: 1.5936  D(x): 0.5447    D(G(z)): 0.0447 / 0.2645
[1/5][900/1583] Loss_D: 0.8545  Loss_G: 4.2344  D(x): 0.9237    D(G(z)): 0.4734 / 0.0229
[1/5][950/1583] Loss_D: 0.5500  Loss_G: 3.3831  D(x): 0.7885    D(G(z)): 0.1998 / 0.0612
[1/5][1000/1583]        Loss_D: 0.5960  Loss_G: 4.6958  D(x): 0.9005    D(G(z)): 0.3427 / 0.0150
[1/5][1050/1583]        Loss_D: 0.5912  Loss_G: 2.6033  D(x): 0.6482    D(G(z)): 0.0663 / 0.1031
[1/5][1100/1583]        Loss_D: 0.9515  Loss_G: 5.2913  D(x): 0.8725    D(G(z)): 0.4687 / 0.0092
[1/5][1150/1583]        Loss_D: 0.5580  Loss_G: 3.2972  D(x): 0.8814    D(G(z)): 0.3090 / 0.0537
[1/5][1200/1583]        Loss_D: 1.1461  Loss_G: 0.9476  D(x): 0.4131    D(G(z)): 0.0189 / 0.4413
[1/5][1250/1583]        Loss_D: 0.6168  Loss_G: 4.0134  D(x): 0.8580    D(G(z)): 0.3055 / 0.0269
[1/5][1300/1583]        Loss_D: 0.5301  Loss_G: 3.8858  D(x): 0.9201    D(G(z)): 0.3122 / 0.0296
[1/5][1350/1583]        Loss_D: 0.6131  Loss_G: 2.8594  D(x): 0.7811    D(G(z)): 0.2471 / 0.0791
[1/5][1400/1583]        Loss_D: 0.9951  Loss_G: 1.7822  D(x): 0.6227    D(G(z)): 0.3034 / 0.2146
[1/5][1450/1583]        Loss_D: 0.4611  Loss_G: 2.5508  D(x): 0.7326    D(G(z)): 0.0763 / 0.1128
[1/5][1500/1583]        Loss_D: 0.4965  Loss_G: 2.6068  D(x): 0.8356    D(G(z)): 0.2374 / 0.1011
[1/5][1550/1583]        Loss_D: 1.4584  Loss_G: 4.7512  D(x): 0.9392    D(G(z)): 0.6799 / 0.0160
[2/5][0/1583]   Loss_D: 0.4193  Loss_G: 2.6057  D(x): 0.7777    D(G(z)): 0.1259 / 0.1003
[2/5][50/1583]  Loss_D: 0.5835  Loss_G: 3.2384  D(x): 0.8721    D(G(z)): 0.3276 / 0.0510
[2/5][100/1583] Loss_D: 0.5853  Loss_G: 3.2728  D(x): 0.8013    D(G(z)): 0.2637 / 0.0529
[2/5][150/1583] Loss_D: 0.6370  Loss_G: 3.5247  D(x): 0.8720    D(G(z)): 0.3603 / 0.0426
[2/5][200/1583] Loss_D: 0.7683  Loss_G: 4.1482  D(x): 0.9096    D(G(z)): 0.4515 / 0.0274
[2/5][250/1583] Loss_D: 0.4617  Loss_G: 2.7579  D(x): 0.8163    D(G(z)): 0.1902 / 0.0831
[2/5][300/1583] Loss_D: 0.6355  Loss_G: 4.1777  D(x): 0.9490    D(G(z)): 0.4036 / 0.0238
[2/5][350/1583] Loss_D: 0.6522  Loss_G: 2.2862  D(x): 0.7462    D(G(z)): 0.2419 / 0.1343
[2/5][400/1583] Loss_D: 1.2008  Loss_G: 4.1554  D(x): 0.8870    D(G(z)): 0.6028 / 0.0254
[2/5][450/1583] Loss_D: 0.7300  Loss_G: 1.5709  D(x): 0.7383    D(G(z)): 0.2852 / 0.2654
[2/5][500/1583] Loss_D: 0.5974  Loss_G: 2.1474  D(x): 0.6345    D(G(z)): 0.0660 / 0.1613
[2/5][550/1583] Loss_D: 0.5099  Loss_G: 3.7225  D(x): 0.8902    D(G(z)): 0.2988 / 0.0324
[2/5][600/1583] Loss_D: 1.3572  Loss_G: 5.6093  D(x): 0.9563    D(G(z)): 0.6824 / 0.0059
[2/5][650/1583] Loss_D: 0.6382  Loss_G: 1.6281  D(x): 0.6235    D(G(z)): 0.0915 / 0.2419
[2/5][700/1583] Loss_D: 0.6769  Loss_G: 3.1289  D(x): 0.8553    D(G(z)): 0.3436 / 0.0771
[2/5][750/1583] Loss_D: 0.9135  Loss_G: 2.7677  D(x): 0.7087    D(G(z)): 0.3676 / 0.0852
[2/5][800/1583] Loss_D: 0.5959  Loss_G: 2.5488  D(x): 0.7787    D(G(z)): 0.2497 / 0.1019
[2/5][850/1583] Loss_D: 0.6833  Loss_G: 2.0409  D(x): 0.6336    D(G(z)): 0.1326 / 0.1690
[2/5][900/1583] Loss_D: 0.7005  Loss_G: 1.5728  D(x): 0.5796    D(G(z)): 0.0840 / 0.2479
[2/5][950/1583] Loss_D: 0.5950  Loss_G: 1.8992  D(x): 0.6482    D(G(z)): 0.0816 / 0.1857
[2/5][1000/1583]        Loss_D: 0.7562  Loss_G: 2.1999  D(x): 0.5471    D(G(z)): 0.0380 / 0.1596
[2/5][1050/1583]        Loss_D: 2.1895  Loss_G: 0.9317  D(x): 0.1703    D(G(z)): 0.0295 / 0.4515
[2/5][1100/1583]        Loss_D: 0.9062  Loss_G: 4.2884  D(x): 0.9376    D(G(z)): 0.5196 / 0.0196
[2/5][1150/1583]        Loss_D: 0.7324  Loss_G: 3.7264  D(x): 0.9104    D(G(z)): 0.4157 / 0.0353
[2/5][1200/1583]        Loss_D: 0.4747  Loss_G: 2.2637  D(x): 0.7467    D(G(z)): 0.1359 / 0.1402
[2/5][1250/1583]        Loss_D: 0.5197  Loss_G: 2.8963  D(x): 0.8303    D(G(z)): 0.2524 / 0.0731
[2/5][1300/1583]        Loss_D: 0.5529  Loss_G: 2.9684  D(x): 0.8418    D(G(z)): 0.2874 / 0.0668
[2/5][1350/1583]        Loss_D: 0.8456  Loss_G: 2.8286  D(x): 0.7357    D(G(z)): 0.3612 / 0.0818
[2/5][1400/1583]        Loss_D: 0.6503  Loss_G: 3.2244  D(x): 0.8971    D(G(z)): 0.3705 / 0.0547
[2/5][1450/1583]        Loss_D: 1.9256  Loss_G: 5.8077  D(x): 0.9719    D(G(z)): 0.7875 / 0.0058
[2/5][1500/1583]        Loss_D: 0.5821  Loss_G: 3.0173  D(x): 0.7724    D(G(z)): 0.2264 / 0.0656
[2/5][1550/1583]        Loss_D: 0.6522  Loss_G: 1.6340  D(x): 0.6143    D(G(z)): 0.0920 / 0.2473
[3/5][0/1583]   Loss_D: 1.7587  Loss_G: 0.3800  D(x): 0.2138    D(G(z)): 0.0124 / 0.7149
[3/5][50/1583]  Loss_D: 0.6285  Loss_G: 1.7274  D(x): 0.7018    D(G(z)): 0.1955 / 0.2135
[3/5][100/1583] Loss_D: 1.5421  Loss_G: 5.0975  D(x): 0.9517    D(G(z)): 0.7230 / 0.0099
[3/5][150/1583] Loss_D: 0.5885  Loss_G: 2.8417  D(x): 0.8255    D(G(z)): 0.2855 / 0.0797
[3/5][200/1583] Loss_D: 0.5335  Loss_G: 1.5141  D(x): 0.7147    D(G(z)): 0.1451 / 0.2559
[3/5][250/1583] Loss_D: 0.8832  Loss_G: 3.5427  D(x): 0.8948    D(G(z)): 0.4498 / 0.0455
[3/5][300/1583] Loss_D: 0.5486  Loss_G: 1.7794  D(x): 0.7131    D(G(z)): 0.1533 / 0.2037
[3/5][350/1583] Loss_D: 0.7178  Loss_G: 2.7095  D(x): 0.7989    D(G(z)): 0.3475 / 0.0913
[3/5][400/1583] Loss_D: 0.5809  Loss_G: 1.7390  D(x): 0.7196    D(G(z)): 0.1781 / 0.2194
[3/5][450/1583] Loss_D: 0.8536  Loss_G: 3.5861  D(x): 0.8830    D(G(z)): 0.4716 / 0.0416
[3/5][500/1583] Loss_D: 0.7952  Loss_G: 1.4783  D(x): 0.5296    D(G(z)): 0.0673 / 0.2896
[3/5][550/1583] Loss_D: 2.9799  Loss_G: 6.5446  D(x): 0.9852    D(G(z)): 0.8837 / 0.0029
[3/5][600/1583] Loss_D: 0.7653  Loss_G: 2.8613  D(x): 0.8142    D(G(z)): 0.3827 / 0.0772
[3/5][650/1583] Loss_D: 0.6686  Loss_G: 2.8368  D(x): 0.8508    D(G(z)): 0.3486 / 0.0794
[3/5][700/1583] Loss_D: 0.4576  Loss_G: 2.5251  D(x): 0.7882    D(G(z)): 0.1727 / 0.1000
[3/5][750/1583] Loss_D: 0.6471  Loss_G: 2.6482  D(x): 0.8044    D(G(z)): 0.2935 / 0.0969
[3/5][800/1583] Loss_D: 0.6649  Loss_G: 1.5387  D(x): 0.6034    D(G(z)): 0.0943 / 0.2517
[3/5][850/1583] Loss_D: 1.0407  Loss_G: 3.3057  D(x): 0.8926    D(G(z)): 0.5609 / 0.0498
[3/5][900/1583] Loss_D: 0.5002  Loss_G: 2.1851  D(x): 0.7248    D(G(z)): 0.1248 / 0.1486
[3/5][950/1583] Loss_D: 0.8142  Loss_G: 1.4080  D(x): 0.5191    D(G(z)): 0.0439 / 0.2979
[3/5][1000/1583]        Loss_D: 4.1161  Loss_G: 0.7369  D(x): 0.0288    D(G(z)): 0.0051 / 0.5529
[3/5][1050/1583]        Loss_D: 0.4978  Loss_G: 2.3390  D(x): 0.7696    D(G(z)): 0.1818 / 0.1195
[3/5][1100/1583]        Loss_D: 0.6442  Loss_G: 2.4792  D(x): 0.7098    D(G(z)): 0.2003 / 0.1116
[3/5][1150/1583]        Loss_D: 0.8305  Loss_G: 1.7473  D(x): 0.5973    D(G(z)): 0.1954 / 0.2186
[3/5][1200/1583]        Loss_D: 0.6362  Loss_G: 2.8213  D(x): 0.7560    D(G(z)): 0.2582 / 0.0749
[3/5][1250/1583]        Loss_D: 0.8410  Loss_G: 4.0181  D(x): 0.9007    D(G(z)): 0.4703 / 0.0267
[3/5][1300/1583]        Loss_D: 0.5280  Loss_G: 1.9016  D(x): 0.6709    D(G(z)): 0.0732 / 0.1838
[3/5][1350/1583]        Loss_D: 0.7867  Loss_G: 1.2578  D(x): 0.5556    D(G(z)): 0.0909 / 0.3339
[3/5][1400/1583]        Loss_D: 0.6212  Loss_G: 2.1267  D(x): 0.6733    D(G(z)): 0.1587 / 0.1568
[3/5][1450/1583]        Loss_D: 0.7302  Loss_G: 3.3170  D(x): 0.8947    D(G(z)): 0.4171 / 0.0463
[3/5][1500/1583]        Loss_D: 0.8142  Loss_G: 2.0735  D(x): 0.7116    D(G(z)): 0.3122 / 0.1594
[3/5][1550/1583]        Loss_D: 0.5078  Loss_G: 3.0857  D(x): 0.8632    D(G(z)): 0.2742 / 0.0612
[4/5][0/1583]   Loss_D: 0.6122  Loss_G: 3.1503  D(x): 0.8047    D(G(z)): 0.2904 / 0.0611
[4/5][50/1583]  Loss_D: 0.6374  Loss_G: 3.5315  D(x): 0.9105    D(G(z)): 0.3751 / 0.0438
[4/5][100/1583] Loss_D: 0.6631  Loss_G: 2.0490  D(x): 0.6502    D(G(z)): 0.1412 / 0.1698
[4/5][150/1583] Loss_D: 1.9588  Loss_G: 0.3717  D(x): 0.2208    D(G(z)): 0.0454 / 0.7271
[4/5][200/1583] Loss_D: 0.9403  Loss_G: 3.1536  D(x): 0.7958    D(G(z)): 0.4448 / 0.0616
[4/5][250/1583] Loss_D: 0.5813  Loss_G: 1.7517  D(x): 0.6633    D(G(z)): 0.1200 / 0.2110
[4/5][300/1583] Loss_D: 0.6037  Loss_G: 2.6966  D(x): 0.7774    D(G(z)): 0.2648 / 0.0862
[4/5][350/1583] Loss_D: 0.5110  Loss_G: 2.6701  D(x): 0.7676    D(G(z)): 0.1830 / 0.0945
[4/5][400/1583] Loss_D: 0.7620  Loss_G: 1.6387  D(x): 0.5405    D(G(z)): 0.0434 / 0.2531
[4/5][450/1583] Loss_D: 0.7376  Loss_G: 1.6293  D(x): 0.5784    D(G(z)): 0.0967 / 0.2386
[4/5][500/1583] Loss_D: 0.7782  Loss_G: 0.9095  D(x): 0.5480    D(G(z)): 0.0823 / 0.4508
[4/5][550/1583] Loss_D: 0.5541  Loss_G: 3.0480  D(x): 0.8816    D(G(z)): 0.3181 / 0.0604
[4/5][600/1583] Loss_D: 0.7756  Loss_G: 2.7610  D(x): 0.7517    D(G(z)): 0.3436 / 0.0776
[4/5][650/1583] Loss_D: 0.5770  Loss_G: 2.9508  D(x): 0.8607    D(G(z)): 0.3109 / 0.0729
[4/5][700/1583] Loss_D: 0.5425  Loss_G: 3.2729  D(x): 0.8931    D(G(z)): 0.3136 / 0.0509
[4/5][750/1583] Loss_D: 0.8674  Loss_G: 1.8879  D(x): 0.5125    D(G(z)): 0.0612 / 0.2064
[4/5][800/1583] Loss_D: 0.5066  Loss_G: 2.6371  D(x): 0.8471    D(G(z)): 0.2595 / 0.0989
[4/5][850/1583] Loss_D: 0.7728  Loss_G: 3.1481  D(x): 0.8230    D(G(z)): 0.3779 / 0.0620
[4/5][900/1583] Loss_D: 0.5977  Loss_G: 3.3086  D(x): 0.8532    D(G(z)): 0.3189 / 0.0471
[4/5][950/1583] Loss_D: 0.4288  Loss_G: 2.8824  D(x): 0.8516    D(G(z)): 0.2069 / 0.0766
[4/5][1000/1583]        Loss_D: 0.5891  Loss_G: 2.9834  D(x): 0.8491    D(G(z)): 0.3117 / 0.0667
[4/5][1050/1583]        Loss_D: 0.9709  Loss_G: 3.8847  D(x): 0.8438    D(G(z)): 0.4910 / 0.0278
[4/5][1100/1583]        Loss_D: 0.4382  Loss_G: 2.6705  D(x): 0.8662    D(G(z)): 0.2277 / 0.0912
[4/5][1150/1583]        Loss_D: 0.5530  Loss_G: 1.9816  D(x): 0.7754    D(G(z)): 0.2229 / 0.1753
[4/5][1200/1583]        Loss_D: 0.4281  Loss_G: 3.0041  D(x): 0.8170    D(G(z)): 0.1765 / 0.0664
[4/5][1250/1583]        Loss_D: 0.6017  Loss_G: 1.6758  D(x): 0.6972    D(G(z)): 0.1774 / 0.2295
[4/5][1300/1583]        Loss_D: 3.0598  Loss_G: 5.4203  D(x): 0.9827    D(G(z)): 0.9255 / 0.0091
[4/5][1350/1583]        Loss_D: 0.5566  Loss_G: 2.4434  D(x): 0.8195    D(G(z)): 0.2647 / 0.1139
[4/5][1400/1583]        Loss_D: 0.5802  Loss_G: 1.6129  D(x): 0.6773    D(G(z)): 0.1315 / 0.2377
[4/5][1450/1583]        Loss_D: 0.6888  Loss_G: 3.6280  D(x): 0.9017    D(G(z)): 0.4098 / 0.0345
[4/5][1500/1583]        Loss_D: 0.5164  Loss_G: 2.0016  D(x): 0.7806    D(G(z)): 0.2010 / 0.1690
[4/5][1550/1583]        Loss_D: 0.9306  Loss_G: 1.3858  D(x): 0.4671    D(G(z)): 0.0457 / 0.2987
3.3.5 结果

最后,让我们看看我们是如何做到的。在这里,我们将看看三个不同的结果。首先,我们将看到D和 G的损失在训练期间是如何变化的。其次,我们将可视化在每个epoch的 fixed_noise batch中的输出。第三,我们将 查看来自G的紧邻一批实际数据的一批假数据。

损失与训练迭代

下面是D&G的损失与训练迭代的关系图。

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

在这里插入图片描述
G的过程可视化
记住在每个训练epoch之后我们如何在fixed_noise batch中保存生成器的输出。现在,我们可以通过动画可视化G的训练进度。按播放按钮 开始动画。

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

在这里插入图片描述
真实图像 vs 伪图像

最后,让我们一起看看一些真实的图像和伪图像。

# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))

# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 在最后一个epoch中绘制伪图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

在这里插入图片描述

posted @ 2022-03-22 21:18  小Aer  阅读(17)  评论(0编辑  收藏  举报  来源