加载pytorch格式的dcgan已训练网络并测试
建立的dcgan网络通过之后可以得到生成网络netG.pkl文件和鉴别网络netD.pkl文件,加载这些网络输入参数即可得到结果。这里显示了生成网络的加载及测试。同时也调用了网络结构显示的库,以pdf的形式显示所加载的网络的具体结构。
'''
Created on 2020年10月27日
@author: afeng
'''
import torch
import torchvision.utils as vutils
import numpy as np
from matplotlib import pyplot as plt
from torchviz import make_dot
from dcgan_facies_model import Generator
from tensorflow.python.keras.layers import noise
def loadModel(fileName):
trained_netG=torch.load(fileName)
#print(trained_netG)
return trained_netG
def testTrainNetG(filename):
loaded_netG = loadModel(filename)
#print(loaded_netG)
b_size=64
nz=100
device=torch.device('cuda:0')
noise = torch.randn(b_size, nz, 1, 1, device=device)
pred = loaded_netG(noise)
print(pred.shape)
#plt.imshow(np.transpose(vutils.make_grid(pred.to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
plt.imshow(np.transpose(vutils.make_grid(pred[0].to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
plt.show()
#saveNet2PDFFile(loaded_netG, noise)
def saveNet2PDFFile(loaded_netG, noise):
#plot the net model as pdf file
net_plot = make_dot(loaded_netG(noise), params=dict(loaded_netG.named_parameters()))
#net_plot = make_dot(loaded_netG(noise))
net_plot.view("loaded_net")
if __name__ == '__main__':
filename='trained_netG.pkl'
testTrainNetG(filename)
pass
下图展示了所加载网络的结构
netG网络的输出结果不再展示,和mnist的手写图像差不多。