用pytorch实现对抗生成网络
最近在学习深度学习编程,采用的深度学习框架是pytorch,看的书主要是陈云编著的《深度学习框架PyTorch入门与实践》、廖星宇编著的《深度学习入门之PyTorch》、肖志清的《神经网络与PyTorch实践》,都是入门的学习材料,适合初学者。
通过近1个多月的学习,基本算是入门了,后面将深度学习与实践。这里分享一个《神经网络与PyTorch实践》中对抗生成网络的例子。它是用对抗生成网络的方法,训练CIFAR-10的数据集,训练模型。
生成网络gnet将大小为(64,11)的潜在张量转化为大小为(3,32,32)的假数据;鉴别网络dnet将大小为(3,32,32)的数据转化为大小为
(1,1,1)的对数赔率张量。下面是整个模型的python代码,包括(1)数据加载,(2)模型搭建,(3)模型训练与模型测试。
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10,CIFAR100
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchviz import make_dot
dataset = CIFAR100(root='./data',
download=True,
transform= transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
#check the data
#for batch_idx, data in enumerate(dataloader):
# real_images, _ = data
# print('real_images size = {}'.format(real_images.size()))
# batch_size = real_images.size(0)
# print('#{} has {} images.'.format(batch_idx, batch_size))
# if batch_idx %100 ==0:
# path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
# save_image(real_images, path, normalize=True)
#construct the generator and discrimiter network
latent_size=64 #潜在大小
n_channel=3 #输出通道数
n_g_feature=64 #生成网络隐藏层大小
#construct the generator
gnet= nn.Sequential(
#输入大小 == (64, 1, 1)
nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),
nn.BatchNorm2d(4*n_g_feature),
nn.ReLU(),
#大小 = (256,4,4)
nn.ConvTranspose2d(4*n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(2*n_g_feature),
nn.ReLU(),
#大小 = (128, 8,8)
nn.ConvTranspose2d(2*n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias= False),
nn.BatchNorm2d(n_g_feature),
nn.ReLU(),
#大小 = (64,16,16)
nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1),
nn.Sigmoid(),
#图片大小 = (3, 32, 32)
)
#define the instance of GeneratorNet
print(gnet)
if torch.cuda.is_available():
gnet.to(torch.device('cuda:0'))
#construct the discrimator
n_d_feature = 64 #鉴别网络隐藏层大小
dnet = nn.Sequential(
#图片大小 = (3,32,32)
nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
#大小 = (63,16,16)
nn.Conv2d(n_d_feature, 2*n_d_feature, kernel_size=4, stride=2, padding=1, bias= False),
nn.BatchNorm2d(2*n_d_feature),
nn.LeakyReLU(0.2),
#大小 = (128, 8,8)
nn.Conv2d(2*n_d_feature, 4*n_d_feature, kernel_size=4, stride=2, padding=1, bias= False),
nn.BatchNorm2d(4*n_d_feature),
nn.LeakyReLU(0.2),
#大小 = (256,4,4)
nn.Conv2d(4*n_d_feature, 1, kernel_size=4),
#对数赔率张量大小=(1,1,1)
#nn.Sigmoid()
)
print(dnet)
if torch.cuda.is_available():
dnet.to(torch.device('cuda:0'))
#initialization for gnet and dnet
def weights_init(m):
if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
init.xavier_normal_(m.weight)
elif type(m) == nn.BatchNorm2d:
init.normal_(m.weight, 1.0, 0.02)
init.constant_(m.bias, 0)
gnet.apply(weights_init)
dnet.apply(weights_init)
#网络的训练和使用
#要构造一个损失函数并对它进行优化
#定义损失
criterion = nn.BCEWithLogitsLoss()
#定义优化器
goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
#用于测试的噪声,用来查看相同的潜在张量在训练过程中生成图片的变换
batch_size=64
fixed_noises = torch.randn(batch_size, latent_size, 1,1)
#save the net to file for check
y=gnet(fixed_noises)
vise_graph = make_dot(y, params=dict(gnet.named_parameters()))
vise_graph.view(filename='gnet')
y=dnet(y)
vise_graph = make_dot(y)
vise_graph.view(filename='dnet')
#训练过程
epoch_num=10
for epoch in range(epoch_num):
for batch_idx, data in enumerate(dataloader):
#载入本批次数据
real_images,_ = data
batch_size = real_images.size(0)
#训练鉴别网络
labels = torch.ones(batch_size) #设置真实数据对应标签为1
preds = dnet(real_images) #对真实数据进行判别
outputs = preds.reshape(-1)
dloss_real = criterion(outputs, labels) #真实数据的鉴别损失
dmean_real = outputs.sigmoid().mean() #计算鉴别器将多少比例的真实数据判定为真,仅用于输出显示
noises = torch.randn(batch_size, latent_size, 1,1) #潜在噪声
fake_images = gnet(noises) #生成假数据
labels = torch.zeros(batch_size) #假数据对应标签为0
fake = fake_images.detach() #是的梯度的计算不回溯到生成网络,可用于加快训练速度。删去此步,结果不变
preds = dnet(fake)
outputs = preds.view(-1)
dloss_fake = criterion(outputs, labels) #假数据的鉴别损失
dmean_fake = outputs.sigmoid().mean() #计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
dloss = dloss_real+dloss_fake
dnet.zero_grad()
dloss.backward()
doptimizer.step()
#训练生成网络
labels = torch.ones(batch_size) #生成网络希望所有生成的数据都是被认为时真的
preds = dnet(fake_images) #让假数据通过假别网络
outputs = preds.view(-1)
gloss = criterion(outputs, labels) #从真数据看到的损失
gmean_fake = outputs.sigmoid().mean() #计算鉴别器将多少比例的假数据判断为真,仅用于输出显示
gnet.zero_grad()
gloss.backward()
goptimizer.step()
#输出本步训练结果
print('[{}/{}]'.format(epoch, epoch_num)+
'[{}/{}]'.format(batch_idx, len(dataloader))+
'鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss)+
'真实数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format(dmean_real, dmean_fake, gmean_fake))
if batch_idx %100 == 0:
fake = gnet(fixed_noises) #由固定潜在征粮生成假数据
save_image(fake, './data/images_epoch{:02d}_batch{:03d}.png'.format(epoch, batch_idx)) #保存假数据
#保存训练的网络
torch.save(gnet, 'gnet.pkl')
torch.save(dnet, 'dnet.pkl')
结果如下