加载pytorch格式的dcgan已训练网络并测试

建立的dcgan网络通过之后可以得到生成网络netG.pkl文件和鉴别网络netD.pkl文件,加载这些网络输入参数即可得到结果。这里显示了生成网络的加载及测试。同时也调用了网络结构显示的库,以pdf的形式显示所加载的网络的具体结构。

'''
Created on 2020年10月27日

@author: afeng
'''

import torch
import torchvision.utils as vutils 
import numpy as np
from matplotlib import pyplot as plt
from torchviz import make_dot
from dcgan_facies_model import Generator
from tensorflow.python.keras.layers import noise

def loadModel(fileName):   
    trained_netG=torch.load(fileName)
    #print(trained_netG)
    return trained_netG

def testTrainNetG(filename):    
    loaded_netG = loadModel(filename)
    #print(loaded_netG)

    b_size=64
    nz=100
    device=torch.device('cuda:0')
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    pred = loaded_netG(noise)
    print(pred.shape)

    #plt.imshow(np.transpose(vutils.make_grid(pred.to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
    plt.imshow(np.transpose(vutils.make_grid(pred[0].to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
    plt.show()    
    #saveNet2PDFFile(loaded_netG, noise)    

def saveNet2PDFFile(loaded_netG, noise):
    #plot the net model as pdf file 
    net_plot = make_dot(loaded_netG(noise), params=dict(loaded_netG.named_parameters()))
    #net_plot = make_dot(loaded_netG(noise))
    net_plot.view("loaded_net")


if __name__ == '__main__':
    filename='trained_netG.pkl'
    testTrainNetG(filename)
    pass

 下图展示了所加载网络的结构

netG网络的输出结果不再展示,和mnist的手写图像差不多。

 

posted @ 2022-08-21 10:12  Oliver2022  阅读(33)  评论(0编辑  收藏  举报