LEP+低秩+神经网络去噪

from __future__ import print_function
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.misc
import os
import numpy as np

from models.resnet import ResNet
from models.unet import UNet
from models.skip import skip
import torch
import torch.optim

from utils.inpainting_utils import *
# cuDNN settings: benchmark mode lets cuDNN time candidate convolution
# algorithms and reuse the fastest one (most effective when input sizes repeat).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type applied via .type(dtype) to the loss, every network, input and
# target. Falls back to CPU tensors when no CUDA device is available instead
# of failing as soon as a tensor is moved to the GPU.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

PLOT = True  # not referenced in this script
imsize = -1  # passed to get_image(); presumably -1 means "keep original size" -- confirm in utils
dim_div_by = 64  # not referenced in this script
NET_TYPE = 'skip_depth6'  # not referenced; the network is built explicitly with skip(...)

# Data locations (absolute paths on the author's machine -- adjust as needed).
# Every file name found in iteation_LEP must also exist in LEP and ORI.
iteation_LEP = '/home/hxj/桌面/PG/test/iteation+LEP/'  # drives the list of images to process
LEP = '/home/hxj/桌面/PG/test/LEP-only/'
ORI = '/home/hxj/gluon-tutorials/GAN/MultiPIE/YaleB_test_crop_gray/'  # path suggests cropped grayscale YaleB test faces
img_name = 'yaleB38_P00A-130E+20.png'  # initial value only; rebound by the per-image loop
real_face_name = 'data/face/reSVD10.png'  # extra target shared by every image; name suggests an SVD (low-rank) result -- confirm

# Network / optimisation settings.
pad = 'reflection'  # padding mode passed to skip(); 'zero' is the alternative
OPT_OVER = 'net'  # passed to get_params(); presumably "optimise the network weights only"
OPTIMIZER = 'adam'
INPUT = 'noise'  # input kind passed to get_noise()
input_depth = 32  # channels of the network input (an alternative setting of 4 was left disabled)
num_iter = 600
param_noise = False  # consumed by closure(): enables parameter-noise regularisation
figsize = 5  # not referenced in this script
reg_noise_std = 0.03  # std of the noise closure() adds to the network input each step
LR = 0.01
# Mean-squared error used for every loss term in closure().
mse = torch.nn.MSELoss().type(dtype)
#i = 0
def closure():
    """Run one forward/backward pass and return the weighted reconstruction loss.

    Passed as the closure to ``optimize(...)`` by the per-image loop. It does
    not take arguments; instead it reads module-level state prepared for the
    current image: ``net``, ``net_input_saved``, ``noise``, ``itLEP_var``,
    ``ORI_var``, ``LEP_var``, ``RF_var`` and ``img_name``, plus the
    configuration globals ``param_noise``, ``reg_noise_std`` and ``mse``.

    Gradients are accumulated with ``backward()``; this function does not zero
    them itself (presumably ``optimize`` does so before each call -- confirm).

    Returns:
        The scalar loss tensor.
    """
    if param_noise:
        # Perturb every 4-D (convolution) weight tensor in place. The previous
        # ``n = n + ...`` only rebound the loop variable and never changed the
        # parameters. no_grad keeps this update out of the autograd graph.
        with torch.no_grad():
            for n in [x for x in net.parameters() if len(x.size()) == 4]:
                n.add_(n.detach().clone().normal_() * n.std() / 50)

    net_input = net_input_saved
    if reg_noise_std > 0:
        # Resample the shared ``noise`` buffer in place and add it to the
        # fixed input so every step sees a slightly different input.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Weighted sum of MSE terms against the four target images
    # (weights 1.0, 0.1, 0.2 and 0.5 respectively).
    total_loss = (mse(out, itLEP_var)
                  + mse(out, ORI_var) * 0.1
                  + mse(out, LEP_var) * 0.2
                  + mse(out, RF_var) * 0.5)
    total_loss.backward()

    # NOTE: despite the "Iteration" label, this prints the current image name.
    print('Iteration %s     Loss %f' % (img_name, total_loss.item()), '\r', end='')

    return total_loss
# Extra target image shared by every file (loaded once).
RF_pil, RF_np = get_image(real_face_name, imsize)
RF_var = np_to_torch(RF_np).type(dtype)

# Restored images are written here; the trailing "0.01" presumably records
# the learning rate (confirm). Created up front so the first save cannot fail.
out_dir = 'result/noise_input/0.01/'
os.makedirs(out_dir, exist_ok=True)

# The iterative-LEP directory drives the file list; the same file name is then
# loaded from the LEP and ORI directories. Sorted for a reproducible order and
# restricted to regular files so stray sub-directories are skipped.
files = sorted(name for name in os.listdir(iteation_LEP)
               if os.path.isfile(os.path.join(iteation_LEP, name)))
for img_name in files:
    # Presumably (C, H, W) arrays: shape[0] is used below as the channel
    # count and shape[1:] as the spatial size.
    itLEP_pil, itLEP_np = get_image(os.path.join(iteation_LEP, img_name), imsize)
    LEP_pil, LEP_np = get_image(os.path.join(LEP, img_name), imsize)
    ORI_pil, ORI_np = get_image(os.path.join(ORI, img_name), imsize)

    # Targets read by closure() through module globals.
    itLEP_var = np_to_torch(itLEP_np).type(dtype)
    LEP_var = np_to_torch(LEP_np).type(dtype)
    ORI_var = np_to_torch(ORI_np).type(dtype)

    # A freshly initialised network for every image.
    net = skip(input_depth, itLEP_np.shape[0],
               num_channels_down=[128] * 5,
               num_channels_up=[128] * 5,
               num_channels_skip=[128] * 5,
               filter_size_up=3, filter_size_down=3,
               upsample_mode='nearest', filter_skip_size=1,
               need_sigmoid=True, need_bias=True, pad=pad,
               act_fun='LeakyReLU').type(dtype)

    net_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
    # Disabled alternative: with input_depth = 4, feed the four target images
    # as input channels instead of noise, e.g.
    #   net_input[0, 0, :] = itLEP_var; net_input[0, 1, :] = LEP_var
    #   net_input[0, 2, :] = ORI_var;   net_input[0, 3, :] = RF_var
    # or use the real-face image alone: net_input = np_to_torch(RF_np).type(dtype)

    # closure() adds regularisation noise to a fixed copy of the input and
    # reuses `noise` as the buffer it resamples in place.
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)

    # Inference only: no_grad avoids building an autograd graph for this pass.
    with torch.no_grad():
        out_np = torch_to_np(net(net_input))
    # Keep channel 0 only, clipped to [0, 1]; assumes single-channel
    # (grayscale) images -- other channels would be dropped.
    img_save = np.clip(out_np, 0, 1)[0]
    out_path = os.path.join(out_dir, img_name)
    if hasattr(scipy.misc, 'toimage'):
        scipy.misc.toimage(img_save, cmin=0.0, cmax=1.0).save(out_path)
    else:
        # scipy.misc.toimage was removed in SciPy 1.2; fall back to matplotlib.
        # This writes an RGBA PNG rendered with a gray colormap rather than an
        # 8-bit grayscale image.
        plt.imsave(out_path, img_save, cmap='gray', vmin=0.0, vmax=1.0)
    

 

posted @ 2019-05-06 09:14  白菜hxj  阅读(757)  评论(0编辑  收藏  举报