G1: Getting Started with Generative Adversarial Networks (GANs)
- 🍨 This article is a learning-record blog post for the 🔗 365-day deep learning training camp
- 🍖 Original author: K同学啊
- 🚀 Source: K同学's study group
This week's tasks: 📌 Basic tasks:
- Understand what a generative adversarial network (GAN) is
- Study the code in this article and run it end to end
🎈 Advanced task:
- Use the trained model to generate new images
I. Theoretical Background
GAN stands for Generative Adversarial Network. The term does not refer to any single specific neural network, but to a family of networks designed around a game-theoretic idea.
A GAN consists of two neural networks called the generator and the discriminator. The generator takes random samples from some noise distribution as input and outputs artificial samples that closely resemble the real samples in the training set; the discriminator takes either a real sample or an artificial one as input, and its goal is to tell the two apart as accurately as possible. The generator and discriminator run alternately, playing against each other, and both improve in the process. Ideally, after enough rounds of this game, the discriminator can no longer judge whether a given sample is real, outputting a 50% real / 50% fake verdict for every sample. At that point the generator's artificial samples are realistic enough to fool the discriminator, the game stops, and we are left with a generator capable of "forging" realistic samples.
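The game described above is usually written as the following minimax objective (a standard formulation from the GAN literature, added here for reference; it is not spelled out in the original text):

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator D is trained to push D(x) toward 1 on real samples and D(G(z)) toward 0 on fakes, while the generator G is trained to push D(G(z)) toward 1; at the ideal equilibrium described above, D outputs 1/2 everywhere.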
1. The Generator
The generator G takes random noise z as input and, through continual fitting during training, eventually outputs a fake sample G(z) with the same size and a similar distribution as the real samples. At heart, the generator is a generative model: it learns assumptions about, and parameters of, the data distribution, and then draws new samples from the learned model.
2. The Discriminator
For an input sample x, the discriminator D outputs a probability D(x) in [0, 1]. The input may be a real sample x from the original dataset or an artificial sample G(z) produced by the generator G. By convention, the closer D(x) is to 1, the more likely the sample is real; the smaller the value, the more likely it is fake. In other words, the discriminator is a binary neural-network classifier whose purpose is not to predict a sample's original class, but to distinguish real from fake. Note that class-label information is used by neither the generator nor the discriminator, which shows that GAN training is an unsupervised learning process.
3. Basic Principle (this part is also well explained in 李宏毅 (Hung-yi Lee)'s lectures)
Researchers originally wanted computers to generate data automatically: for example, after training a model on some pictures of apples, the model should be able to generate new apple pictures on its own; an algorithm with this capability is said to be generative. GAN was not the first generative algorithm, however. Earlier generative methods measured the gap between generated and real images with mean squared error as the loss function, and researchers found that two generated images with identical MSE could look drastically different in quality. It was this shortcoming that led Ian Goodfellow to propose the GAN.
A GAN is composed of two models: a generative model G and a discriminative model D. The first-generation generator 1G takes random noise z as input and produces a crude image. A first-generation discriminator 1D is then trained as a binary classifier, labeling generated images 0 and real images 1. To fool this first-generation discriminator, the first-generation generator optimizes itself and advances to a second generation; once its outputs successfully fool 1D, the discriminator is updated in turn and upgraded to 2D. Repeating this process keeps producing ever newer generations of G and D.
II. Preparation
1. Hyperparameters
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
## Create output folders
os.makedirs("./data/images/", exist_ok=True) ## sample images saved during training
os.makedirs("./data/save/", exist_ok=True) ## where the trained models are saved
os.makedirs("./data/mnist", exist_ok=True) ## where the downloaded dataset is stored
## Hyperparameter configuration
n_epochs=50
batch_size=64
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500
## Image shape: (1, 28, 28); pixels per image: 784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
## Use CUDA (cuda:0) if a GPU is available
cuda = torch.cuda.is_available()
print(cuda)
True
2. Downloading the Data
## Download the MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, download=True, transform=transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
3. The DataLoader
## Wrap the dataset in a DataLoader
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)
III. Building the Models
1. The Discriminator
##### Define the Discriminator ######
## Flatten each 28x28 image into 784 values and pass them through a multilayer
## perceptron with LeakyReLU activations (negative slope 0.2) in between,
## finishing with a sigmoid that yields a probability in [0, 1] for binary classification
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),        # 784 input features -> 512
            nn.LeakyReLU(0.2, inplace=True), # non-linear activation
            nn.Linear(512, 256),             # 512 -> 256
            nn.LeakyReLU(0.2, inplace=True), # non-linear activation
            nn.Linear(256, 1),               # 256 -> 1
            nn.Sigmoid(),                    # sigmoid maps the score into [0, 1] as a probability for binary classification (softmax would be used for multi-class)
        )
    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # flatten each image to a 784-dim vector: (64, 784)
        validity = self.model(img_flat)      # run it through the discriminator network
        return validity                      # the discriminator returns a probability in [0, 1]
2. The Generator
The code defines a block function used to build one stage of the generator model. block takes two parameters, in_feat and out_feat, which give the input and output feature dimensions, plus a normalize flag (True or False) that controls whether batch normalization is applied to the output features. It returns a list of layers, each of which is an nn.Linear, nn.BatchNorm1d, or nn.LeakyReLU. Inside the function, nn.Linear first maps the input features to the output dimension, producing a linear layer. Then, if normalize is True, a batch-normalization layer is appended; nn.BatchNorm1d standardizes the activations across the batch. Finally, a LeakyReLU layer is appended to provide the non-linearity.
Lines 25–28 of the listing define the forward pass: given a noise vector z, the generator model produces a 28*28 image tensor, which the view method then reshapes to (batch_size, channel, height, width) before the generated images are returned.
###### Define the Generator #####
## The input is a 100-dim sample from a standard Gaussian. A first linear layer maps it
## to 128 dims, followed by LeakyReLU; further linear + LeakyReLU blocks widen it,
## and a final linear layer maps it to 784 dims. The closing Tanh squashes the fake
## image values into [-1, 1], matching the distribution of the normalized real data.
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## A reusable building block of the model
        def block(in_feat, out_feat, normalize=True): # block(in, out)
            layers = [nn.Linear(in_feat, out_feat)]          # linear map from in_feat to out_feat dims
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) # batch normalization
            layers.append(nn.LeakyReLU(0.2, inplace=True))   # non-linear activation
            return layers
        ## np.prod(): product of array elements over a given axis: 1*28*28 = 784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), # 100 -> 128, LeakyReLU (no batch norm)
            *block(128, 256),                         # 128 -> 256, batch norm, LeakyReLU
            *block(256, 512),                         # 256 -> 512, batch norm, LeakyReLU
            *block(512, 1024),                        # 512 -> 1024, batch norm, LeakyReLU
            nn.Linear(1024, img_area),                # 1024 -> 784
            nn.Tanh()                                 # map each of the 784 values into [-1, 1]
        )
    ## view(): like numpy's reshape; here reshape to (64, 1, 28, 28)
    def forward(self, z):                          # input: (64, 100) noise
        imgs = self.model(z)                       # pass the noise through the generator model
        imgs = imgs.view(imgs.size(0), *img_shape) # reshape to (64, 1, 28, 28)
        return imgs                                # output: 64 images of size (1, 28, 28)
IV. Training
1. Models, Loss, and Optimizers
## Instantiate the generator and discriminator
generator = Generator()
discriminator = Discriminator()
## First define the loss (binary cross-entropy)
criterion = torch.nn.BCELoss()
## Then define the optimizers, with learning rate 0.0002
## betas: coefficients for the running averages of the gradient and its square
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
## If a GPU is available, run everything on CUDA
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
2. Training the Model
imgs is a tensor of size (64, 1, 28, 28), i.e. one batch of 64 grayscale 28x28 images; _ is the label, which is not used here.
Training the discriminator:
- First flatten the images into 1-D tensors of size (64, 784)
- Wrap them as PyTorch Variables and move them to the GPU
- Define real_label as a (64, 1) tensor of ones
- Define fake_label as a (64, 1) tensor of zeros
- Feed the real images into the discriminator to get real_out, then compute the real-image loss, loss_real_D
- real_scores holds the discriminator's scores on real images; the closer to 1 the better
- Next sample some random noise z, feed it to the generator to get fake images fake_img, feed those to the discriminator to get fake_out, and compute the fake-image loss, loss_fake_D
- fake_scores holds the discriminator's scores on fake images; the closer to 0 the better
- Sum the real and fake losses into loss_D, then backpropagate to update the discriminator's parameters
Training the generator:
- First sample some random noise z and feed it to the generator to get fake images fake_img
- Feed those into the discriminator to get output, and compute loss_G, the loss between the fake images' scores and the real label
- Backpropagate to update the generator's parameters.
Finally, some code prints progress and saves the generated images.
## Train for multiple epochs
for epoch in range(n_epochs):                      # n_epochs: 50
    for i, (imgs, _) in enumerate(dataloader):     # imgs: (64, 1, 28, 28), _: labels (64)
        ## ============================= Train the discriminator ==================
        ## view(): like numpy's reshape; (64, 1, 28, 28) -> (64, 784)
        imgs = imgs.view(imgs.size(0), -1)         # flatten each image to 28*28=784; imgs: (64, 784)
        real_img = Variable(imgs).cuda()           # wrap as Variable and move to the GPU (Variable is a legacy wrapper; in modern PyTorch tensors track gradients directly)
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()  ## label real images as 1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda() ## label fake images as 0
        ## ---------------------
        ## Train Discriminator
        ## Two parts: 1) classify real images as real; 2) classify fake images as fake
        ## ---------------------
        ## Loss on real images
        real_out = discriminator(real_img)         # feed the real images to the discriminator
        loss_real_D = criterion(real_out, real_label) # real-image loss
        real_scores = real_out                     # discriminator scores on real images; the closer to 1 the better
        ## Loss on fake images
        ## detach(): cut the graph here so no gradients flow into G, since G is not updated in this step
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() ## sample random noise of size (64, 100)
        fake_img = generator(z).detach()           ## feed the noise to the generator to produce fake images
        fake_out = discriminator(fake_img)         ## discriminator scores on the fake images
        loss_fake_D = criterion(fake_out, fake_label) ## fake-image loss
        fake_scores = fake_out                     ## for the discriminator, the closer these scores are to 0 the better
        ## Combine the losses and optimize
        loss_D = loss_real_D + loss_fake_D         # total loss: real loss + fake loss
        optimizer_D.zero_grad()                    # zero the gradients before backpropagation
        loss_D.backward()                          # backpropagate the error
        optimizer_D.step()                         # update the parameters
        ## -----------------
        ## Train Generator
        ## Idea: make the discriminator classify the generated fakes as real.
        ## The discriminator is held fixed, the fakes are scored against the real label,
        ## and backpropagation updates only the generator's parameters, so the generator
        ## learns to produce images the discriminator mistakes for real -- the adversarial objective.
        ## -----------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() ## sample random noise
        fake_img = generator(z)                    ## feed the noise to the generator to produce fake images
        output = discriminator(fake_img)           ## discriminator scores on the fakes
        ## Loss and optimization
        loss_G = criterion(output, real_label)     ## loss between the fakes' scores and the real label
        optimizer_G.zero_grad()                    ## zero the gradients
        loss_G.backward()                          ## backpropagate
        optimizer_G.step()                         ## update the generator's parameters
        ## Print training progress
        ## item(): extract the value of a single-element tensor, keeping its type
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## Save sample images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./data/images/%d.png" % batches_done, nrow=5, normalize=True)
[Epoch 0/50] [Batch 299/938] [D loss: 1.158173] [G loss: 1.444359] [D real: 0.828766] [D fake: 0.616032]
[Epoch 0/50] [Batch 599/938] [D loss: 0.907697] [G loss: 1.384978] [D real: 0.661816] [D fake: 0.371830]
[Epoch 0/50] [Batch 899/938] [D loss: 0.691578] [G loss: 1.460195] [D real: 0.706499] [D fake: 0.228762]
[Epoch 1/50] [Batch 299/938] [D loss: 1.291033] [G loss: 0.867498] [D real: 0.432650] [D fake: 0.165322]
[Epoch 1/50] [Batch 599/938] [D loss: 0.755460] [G loss: 1.204757] [D real: 0.629975] [D fake: 0.190666]
[Epoch 1/50] [Batch 899/938] [D loss: 1.006618] [G loss: 0.998110] [D real: 0.524488] [D fake: 0.144524]
[Epoch 2/50] [Batch 299/938] [D loss: 1.014850] [G loss: 0.825372] [D real: 0.488002] [D fake: 0.124625]
[Epoch 2/50] [Batch 599/938] [D loss: 1.219018] [G loss: 2.075435] [D real: 0.853672] [D fake: 0.597609]
[Epoch 2/50] [Batch 899/938] [D loss: 0.904595] [G loss: 3.296604] [D real: 0.875925] [D fake: 0.523923]
[Epoch 3/50] [Batch 299/938] [D loss: 1.195196] [G loss: 3.868027] [D real: 0.941523] [D fake: 0.647388]
[Epoch 3/50] [Batch 599/938] [D loss: 0.674789] [G loss: 1.588359] [D real: 0.703144] [D fake: 0.210794]
[Epoch 3/50] [Batch 899/938] [D loss: 1.023985] [G loss: 0.780757] [D real: 0.512141] [D fake: 0.126003]
[Epoch 4/50] [Batch 299/938] [D loss: 0.785617] [G loss: 2.355174] [D real: 0.846768] [D fake: 0.433337]
[Epoch 4/50] [Batch 599/938] [D loss: 0.768934] [G loss: 2.830851] [D real: 0.859101] [D fake: 0.420156]
[Epoch 4/50] [Batch 899/938] [D loss: 0.608886] [G loss: 1.641987] [D real: 0.748333] [D fake: 0.208987]
[Epoch 5/50] [Batch 299/938] [D loss: 0.826799] [G loss: 1.244515] [D real: 0.668695] [D fake: 0.241504]
[Epoch 5/50] [Batch 599/938] [D loss: 0.856476] [G loss: 2.605964] [D real: 0.819840] [D fake: 0.442883]
[Epoch 5/50] [Batch 899/938] [D loss: 0.772790] [G loss: 1.638871] [D real: 0.671070] [D fake: 0.213962]
[Epoch 6/50] [Batch 299/938] [D loss: 0.727868] [G loss: 1.388534] [D real: 0.664138] [D fake: 0.143495]
[Epoch 6/50] [Batch 599/938] [D loss: 0.803381] [G loss: 1.206940] [D real: 0.656227] [D fake: 0.155422]
[Epoch 6/50] [Batch 899/938] [D loss: 0.940307] [G loss: 2.189455] [D real: 0.801064] [D fake: 0.449602]
[Epoch 7/50] [Batch 299/938] [D loss: 0.990136] [G loss: 1.215084] [D real: 0.573626] [D fake: 0.167803]
[Epoch 7/50] [Batch 599/938] [D loss: 1.492306] [G loss: 1.164526] [D real: 0.422325] [D fake: 0.044536]
[Epoch 7/50] [Batch 899/938] [D loss: 0.736784] [G loss: 1.992318] [D real: 0.775619] [D fake: 0.303254]
[Epoch 8/50] [Batch 299/938] [D loss: 0.729828] [G loss: 1.315253] [D real: 0.726814] [D fake: 0.255552]
[Epoch 8/50] [Batch 599/938] [D loss: 0.708870] [G loss: 2.216428] [D real: 0.771318] [D fake: 0.284945]
[Epoch 8/50] [Batch 899/938] [D loss: 0.908683] [G loss: 1.209612] [D real: 0.629250] [D fake: 0.237226]
[Epoch 9/50] [Batch 299/938] [D loss: 0.702113] [G loss: 1.757013] [D real: 0.730061] [D fake: 0.235253]
[Epoch 9/50] [Batch 599/938] [D loss: 0.845079] [G loss: 1.323123] [D real: 0.650204] [D fake: 0.243090]
[Epoch 9/50] [Batch 899/938] [D loss: 0.865847] [G loss: 1.960042] [D real: 0.749930] [D fake: 0.362776]
[Epoch 10/50] [Batch 299/938] [D loss: 0.972551] [G loss: 2.781687] [D real: 0.845141] [D fake: 0.507452]
[Epoch 10/50] [Batch 599/938] [D loss: 1.019926] [G loss: 1.232936] [D real: 0.619500] [D fake: 0.301336]
[Epoch 10/50] [Batch 899/938] [D loss: 0.912966] [G loss: 1.219540] [D real: 0.676508] [D fake: 0.325706]
[Epoch 11/50] [Batch 299/938] [D loss: 0.782303] [G loss: 1.075927] [D real: 0.759116] [D fake: 0.333795]
[Epoch 11/50] [Batch 599/938] [D loss: 0.845553] [G loss: 2.114520] [D real: 0.826822] [D fake: 0.436879]
[Epoch 11/50] [Batch 899/938] [D loss: 1.161359] [G loss: 1.827609] [D real: 0.738133] [D fake: 0.483912]
[Epoch 12/50] [Batch 299/938] [D loss: 0.854428] [G loss: 1.317717] [D real: 0.658240] [D fake: 0.284326]
[Epoch 12/50] [Batch 599/938] [D loss: 1.215137] [G loss: 0.786159] [D real: 0.500593] [D fake: 0.174575]
[Epoch 12/50] [Batch 899/938] [D loss: 0.915575] [G loss: 1.635875] [D real: 0.720458] [D fake: 0.362735]
[Epoch 13/50] [Batch 299/938] [D loss: 1.011057] [G loss: 1.093758] [D real: 0.617855] [D fake: 0.291908]
[Epoch 13/50] [Batch 599/938] [D loss: 0.988704] [G loss: 1.966827] [D real: 0.817004] [D fake: 0.488661]
[Epoch 13/50] [Batch 899/938] [D loss: 0.865117] [G loss: 1.727183] [D real: 0.773818] [D fake: 0.400255]
[Epoch 14/50] [Batch 299/938] [D loss: 0.875381] [G loss: 1.047587] [D real: 0.645928] [D fake: 0.250452]
[Epoch 14/50] [Batch 599/938] [D loss: 0.966645] [G loss: 2.381181] [D real: 0.815996] [D fake: 0.486138]
[Epoch 14/50] [Batch 899/938] [D loss: 0.938718] [G loss: 1.697883] [D real: 0.716241] [D fake: 0.387768]
[Epoch 15/50] [Batch 299/938] [D loss: 1.009761] [G loss: 0.878475] [D real: 0.567891] [D fake: 0.273345]
[Epoch 15/50] [Batch 599/938] [D loss: 0.985754] [G loss: 2.436576] [D real: 0.823689] [D fake: 0.495946]
[Epoch 15/50] [Batch 899/938] [D loss: 1.090001] [G loss: 0.883752] [D real: 0.555214] [D fake: 0.241083]
[Epoch 16/50] [Batch 299/938] [D loss: 1.005893] [G loss: 1.457063] [D real: 0.705683] [D fake: 0.385423]
[Epoch 16/50] [Batch 599/938] [D loss: 0.866947] [G loss: 1.758667] [D real: 0.727871] [D fake: 0.350871]
[Epoch 16/50] [Batch 899/938] [D loss: 0.875345] [G loss: 1.514280] [D real: 0.702745] [D fake: 0.338786]
[Epoch 17/50] [Batch 299/938] [D loss: 0.867365] [G loss: 1.135895] [D real: 0.603121] [D fake: 0.215900]
[Epoch 17/50] [Batch 599/938] [D loss: 0.960292] [G loss: 0.951580] [D real: 0.561749] [D fake: 0.191260]
[Epoch 17/50] [Batch 899/938] [D loss: 0.905297] [G loss: 1.641576] [D real: 0.778182] [D fake: 0.427020]
[Epoch 18/50] [Batch 299/938] [D loss: 0.926826] [G loss: 1.556072] [D real: 0.727871] [D fake: 0.395448]
[Epoch 18/50] [Batch 599/938] [D loss: 0.890065] [G loss: 1.413342] [D real: 0.707900] [D fake: 0.349580]
[Epoch 18/50] [Batch 899/938] [D loss: 1.150915] [G loss: 0.779538] [D real: 0.477192] [D fake: 0.128615]
[Epoch 19/50] [Batch 299/938] [D loss: 0.964692] [G loss: 0.905429] [D real: 0.579390] [D fake: 0.227936]
[Epoch 19/50] [Batch 599/938] [D loss: 1.078362] [G loss: 0.954290] [D real: 0.514275] [D fake: 0.203652]
[Epoch 19/50] [Batch 899/938] [D loss: 0.984906] [G loss: 1.549595] [D real: 0.683712] [D fake: 0.380527]
[Epoch 20/50] [Batch 299/938] [D loss: 0.949209] [G loss: 1.445453] [D real: 0.655572] [D fake: 0.315084]
[Epoch 20/50] [Batch 599/938] [D loss: 1.046732] [G loss: 1.449630] [D real: 0.712719] [D fake: 0.439065]
[Epoch 20/50] [Batch 899/938] [D loss: 0.923628] [G loss: 1.353765] [D real: 0.701917] [D fake: 0.368214]
[Epoch 21/50] [Batch 299/938] [D loss: 0.897119] [G loss: 1.247460] [D real: 0.617709] [D fake: 0.214891]
[Epoch 21/50] [Batch 599/938] [D loss: 0.947255] [G loss: 1.539796] [D real: 0.666425] [D fake: 0.336521]
[Epoch 21/50] [Batch 899/938] [D loss: 1.023242] [G loss: 1.461393] [D real: 0.781023] [D fake: 0.487648]
[Epoch 22/50] [Batch 299/938] [D loss: 0.842644] [G loss: 1.615600] [D real: 0.685891] [D fake: 0.302716]
[Epoch 22/50] [Batch 599/938] [D loss: 1.042079] [G loss: 1.650457] [D real: 0.787996] [D fake: 0.486913]
[Epoch 22/50] [Batch 899/938] [D loss: 1.070047] [G loss: 1.832334] [D real: 0.775443] [D fake: 0.500319]
[Epoch 23/50] [Batch 299/938] [D loss: 1.023204] [G loss: 2.204949] [D real: 0.820360] [D fake: 0.502777]
[Epoch 23/50] [Batch 599/938] [D loss: 0.738861] [G loss: 1.863115] [D real: 0.767690] [D fake: 0.332002]
[Epoch 23/50] [Batch 899/938] [D loss: 0.869603] [G loss: 1.175084] [D real: 0.672111] [D fake: 0.316175]
[Epoch 24/50] [Batch 299/938] [D loss: 0.988703] [G loss: 0.970555] [D real: 0.557572] [D fake: 0.209298]
[Epoch 24/50] [Batch 599/938] [D loss: 1.084172] [G loss: 1.446266] [D real: 0.732659] [D fake: 0.460710]
[Epoch 24/50] [Batch 899/938] [D loss: 0.966943] [G loss: 0.847046] [D real: 0.636745] [D fake: 0.314692]
[Epoch 25/50] [Batch 299/938] [D loss: 1.081174] [G loss: 0.847780] [D real: 0.520214] [D fake: 0.200271]
[Epoch 25/50] [Batch 599/938] [D loss: 1.109172] [G loss: 2.065813] [D real: 0.801831] [D fake: 0.520555]
[Epoch 25/50] [Batch 899/938] [D loss: 0.925075] [G loss: 1.818030] [D real: 0.684656] [D fake: 0.340149]
[Epoch 26/50] [Batch 299/938] [D loss: 0.955453] [G loss: 1.982675] [D real: 0.749039] [D fake: 0.413411]
[Epoch 26/50] [Batch 599/938] [D loss: 0.825978] [G loss: 1.701131] [D real: 0.791001] [D fake: 0.391063]
[Epoch 26/50] [Batch 899/938] [D loss: 0.995865] [G loss: 0.946962] [D real: 0.584566] [D fake: 0.264646]
[Epoch 27/50] [Batch 299/938] [D loss: 0.851348] [G loss: 1.099684] [D real: 0.647326] [D fake: 0.258442]
[Epoch 27/50] [Batch 599/938] [D loss: 0.856978] [G loss: 1.100666] [D real: 0.706608] [D fake: 0.315352]
[Epoch 27/50] [Batch 899/938] [D loss: 1.027624] [G loss: 1.392488] [D real: 0.609727] [D fake: 0.291136]
[Epoch 28/50] [Batch 299/938] [D loss: 1.080262] [G loss: 1.302729] [D real: 0.534590] [D fake: 0.170610]
[Epoch 28/50] [Batch 599/938] [D loss: 1.084942] [G loss: 1.120615] [D real: 0.522578] [D fake: 0.179799]
[Epoch 28/50] [Batch 899/938] [D loss: 1.217424] [G loss: 1.764962] [D real: 0.709627] [D fake: 0.501658]
[Epoch 29/50] [Batch 299/938] [D loss: 1.002094] [G loss: 1.238462] [D real: 0.575900] [D fake: 0.242113]
[Epoch 29/50] [Batch 599/938] [D loss: 1.032944] [G loss: 2.196985] [D real: 0.777151] [D fake: 0.467646]
[Epoch 29/50] [Batch 899/938] [D loss: 0.863847] [G loss: 1.019327] [D real: 0.608237] [D fake: 0.199215]
[Epoch 30/50] [Batch 299/938] [D loss: 0.881623] [G loss: 1.958373] [D real: 0.742719] [D fake: 0.361224]
[Epoch 30/50] [Batch 599/938] [D loss: 0.929857] [G loss: 1.376175] [D real: 0.663011] [D fake: 0.290203]
[Epoch 30/50] [Batch 899/938] [D loss: 0.792349] [G loss: 1.372166] [D real: 0.737139] [D fake: 0.301457]
[Epoch 31/50] [Batch 299/938] [D loss: 0.917897] [G loss: 1.272001] [D real: 0.668212] [D fake: 0.295221]
[Epoch 31/50] [Batch 599/938] [D loss: 0.967251] [G loss: 1.381095] [D real: 0.619703] [D fake: 0.241604]
[Epoch 31/50] [Batch 899/938] [D loss: 1.032823] [G loss: 1.571394] [D real: 0.699449] [D fake: 0.403713]
[Epoch 32/50] [Batch 299/938] [D loss: 1.029269] [G loss: 1.658660] [D real: 0.712622] [D fake: 0.387062]
[Epoch 32/50] [Batch 599/938] [D loss: 0.908382] [G loss: 1.501176] [D real: 0.637434] [D fake: 0.233541]
[Epoch 32/50] [Batch 899/938] [D loss: 0.737346] [G loss: 1.639206] [D real: 0.740658] [D fake: 0.283942]
[Epoch 33/50] [Batch 299/938] [D loss: 0.881072] [G loss: 1.617361] [D real: 0.692938] [D fake: 0.324864]
[Epoch 33/50] [Batch 599/938] [D loss: 0.947872] [G loss: 1.365358] [D real: 0.613924] [D fake: 0.243315]
[Epoch 33/50] [Batch 899/938] [D loss: 1.197088] [G loss: 1.802043] [D real: 0.756234] [D fake: 0.523054]
[Epoch 34/50] [Batch 299/938] [D loss: 1.070333] [G loss: 0.881454] [D real: 0.547635] [D fake: 0.198637]
[Epoch 34/50] [Batch 599/938] [D loss: 0.732026] [G loss: 1.611722] [D real: 0.732818] [D fake: 0.273538]
[Epoch 34/50] [Batch 899/938] [D loss: 0.856358] [G loss: 1.327878] [D real: 0.673174] [D fake: 0.292236]
[Epoch 35/50] [Batch 299/938] [D loss: 0.842997] [G loss: 1.330260] [D real: 0.692341] [D fake: 0.276241]
[Epoch 35/50] [Batch 599/938] [D loss: 1.158882] [G loss: 2.181961] [D real: 0.790613] [D fake: 0.500981]
[Epoch 35/50] [Batch 899/938] [D loss: 0.938851] [G loss: 0.951667] [D real: 0.687770] [D fake: 0.319927]
[Epoch 36/50] [Batch 299/938] [D loss: 0.778130] [G loss: 1.614185] [D real: 0.688520] [D fake: 0.236985]
[Epoch 36/50] [Batch 599/938] [D loss: 0.860561] [G loss: 1.089725] [D real: 0.681044] [D fake: 0.259213]
[Epoch 36/50] [Batch 899/938] [D loss: 0.987945] [G loss: 1.252129] [D real: 0.616911] [D fake: 0.294552]
[Epoch 37/50] [Batch 299/938] [D loss: 0.959838] [G loss: 1.899603] [D real: 0.672147] [D fake: 0.321734]
[Epoch 37/50] [Batch 599/938] [D loss: 1.145939] [G loss: 0.881653] [D real: 0.521278] [D fake: 0.169943]
[Epoch 37/50] [Batch 899/938] [D loss: 1.081005] [G loss: 1.120957] [D real: 0.680088] [D fake: 0.382852]
[Epoch 38/50] [Batch 299/938] [D loss: 0.817964] [G loss: 2.585278] [D real: 0.825992] [D fake: 0.400205]
[Epoch 38/50] [Batch 599/938] [D loss: 0.842400] [G loss: 1.488824] [D real: 0.719073] [D fake: 0.331329]
[Epoch 38/50] [Batch 899/938] [D loss: 0.966889] [G loss: 0.889513] [D real: 0.591993] [D fake: 0.217513]
[Epoch 39/50] [Batch 299/938] [D loss: 0.913460] [G loss: 2.298306] [D real: 0.788721] [D fake: 0.396490]
[Epoch 39/50] [Batch 599/938] [D loss: 1.126092] [G loss: 2.391943] [D real: 0.788027] [D fake: 0.508409]
[Epoch 39/50] [Batch 899/938] [D loss: 0.967621] [G loss: 1.765748] [D real: 0.776652] [D fake: 0.437789]
[Epoch 40/50] [Batch 299/938] [D loss: 0.911972] [G loss: 1.264209] [D real: 0.615513] [D fake: 0.190590]
[Epoch 40/50] [Batch 599/938] [D loss: 0.808387] [G loss: 1.492971] [D real: 0.700173] [D fake: 0.242417]
[Epoch 40/50] [Batch 899/938] [D loss: 0.945543] [G loss: 1.199135] [D real: 0.641663] [D fake: 0.264667]
[Epoch 41/50] [Batch 299/938] [D loss: 0.767059] [G loss: 1.265899] [D real: 0.729328] [D fake: 0.256548]
[Epoch 41/50] [Batch 599/938] [D loss: 0.768910] [G loss: 1.395854] [D real: 0.698200] [D fake: 0.252217]
[Epoch 41/50] [Batch 899/938] [D loss: 0.828038] [G loss: 1.425793] [D real: 0.676624] [D fake: 0.215663]
[Epoch 42/50] [Batch 299/938] [D loss: 0.786111] [G loss: 1.149347] [D real: 0.679647] [D fake: 0.240033]
[Epoch 42/50] [Batch 599/938] [D loss: 0.926057] [G loss: 1.093621] [D real: 0.611771] [D fake: 0.213749]
[Epoch 42/50] [Batch 899/938] [D loss: 1.051074] [G loss: 1.875097] [D real: 0.780975] [D fake: 0.452756]
[Epoch 43/50] [Batch 299/938] [D loss: 0.949325] [G loss: 1.513711] [D real: 0.725759] [D fake: 0.343389]
[Epoch 43/50] [Batch 599/938] [D loss: 1.072763] [G loss: 1.187461] [D real: 0.599758] [D fake: 0.256029]
[Epoch 43/50] [Batch 899/938] [D loss: 0.936638] [G loss: 1.526531] [D real: 0.653215] [D fake: 0.225740]
[Epoch 44/50] [Batch 299/938] [D loss: 0.783098] [G loss: 1.213412] [D real: 0.683519] [D fake: 0.254234]
[Epoch 44/50] [Batch 599/938] [D loss: 1.032195] [G loss: 1.537429] [D real: 0.784595] [D fake: 0.430993]
[Epoch 44/50] [Batch 899/938] [D loss: 0.841602] [G loss: 1.410475] [D real: 0.681418] [D fake: 0.202322]
[Epoch 45/50] [Batch 299/938] [D loss: 0.913109] [G loss: 1.560791] [D real: 0.653888] [D fake: 0.276527]
[Epoch 45/50] [Batch 599/938] [D loss: 0.827896] [G loss: 1.580814] [D real: 0.702827] [D fake: 0.265357]
[Epoch 45/50] [Batch 899/938] [D loss: 0.721954] [G loss: 1.603410] [D real: 0.723822] [D fake: 0.233086]
[Epoch 46/50] [Batch 299/938] [D loss: 0.811522] [G loss: 1.570267] [D real: 0.741978] [D fake: 0.327063]
[Epoch 46/50] [Batch 599/938] [D loss: 1.024942] [G loss: 1.180711] [D real: 0.570806] [D fake: 0.164485]
[Epoch 46/50] [Batch 899/938] [D loss: 0.845878] [G loss: 1.584793] [D real: 0.627794] [D fake: 0.176946]
[Epoch 47/50] [Batch 299/938] [D loss: 1.037055] [G loss: 1.048466] [D real: 0.606069] [D fake: 0.248434]
[Epoch 47/50] [Batch 599/938] [D loss: 0.873520] [G loss: 1.529568] [D real: 0.665836] [D fake: 0.266380]
[Epoch 47/50] [Batch 899/938] [D loss: 0.909397] [G loss: 0.988783] [D real: 0.699651] [D fake: 0.329167]
[Epoch 48/50] [Batch 299/938] [D loss: 0.762955] [G loss: 1.757964] [D real: 0.691470] [D fake: 0.175175]
[Epoch 48/50] [Batch 599/938] [D loss: 0.751731] [G loss: 1.582819] [D real: 0.695594] [D fake: 0.232840]
[Epoch 48/50] [Batch 899/938] [D loss: 0.920607] [G loss: 2.062231] [D real: 0.841761] [D fake: 0.449201]
[Epoch 49/50] [Batch 299/938] [D loss: 0.803595] [G loss: 1.586111] [D real: 0.708683] [D fake: 0.264936]
[Epoch 49/50] [Batch 599/938] [D loss: 0.767607] [G loss: 1.513981] [D real: 0.702909] [D fake: 0.244586]
[Epoch 49/50] [Batch 899/938] [D loss: 1.228767] [G loss: 0.696118] [D real: 0.512739] [D fake: 0.159993]
3. Saving the Models
## Save the trained models
torch.save(generator.state_dict(), './data/save/generator.pth')
torch.save(discriminator.state_dict(), './data/save/discriminator.pth')