# LEP + low-rank prior + neural-network denoising (LEP+低秩+神经网络去噪)
from __future__ import print_function import matplotlib import matplotlib.pyplot as plt %matplotlib inline import scipy.misc import os import numpy as np from models.resnet import ResNet from models.unet import UNet from models.skip import skip import torch import torch.optim from utils.inpainting_utils import * torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark =True dtype = torch.cuda.FloatTensor PLOT = True imsize = -1 dim_div_by = 64 NET_TYPE = 'skip_depth6' iteation_LEP = '/home/hxj/桌面/PG/test/iteation+LEP/' LEP = '/home/hxj/桌面/PG/test/LEP-only/' ORI = '/home/hxj/gluon-tutorials/GAN/MultiPIE/YaleB_test_crop_gray/' img_name = 'yaleB38_P00A-130E+20.png' real_face_name='data/face/reSVD10.png' pad = 'reflection' # 'zero' OPT_OVER = 'net' OPTIMIZER = 'adam' INPUT = 'noise' input_depth = 32 #input_depth = 4 num_iter = 600 param_noise = False figsize = 5 reg_noise_std = 0.03 LR = 0.01 mse = torch.nn.MSELoss().type(dtype)
def closure():
    """Run one forward/backward pass and return the combined loss.

    Passed to ``optimize`` by the main loop and called once per optimisation
    step. It takes no arguments; instead it reads module-level state that the
    main loop sets up:

    * ``net``
    * ``net_input_saved``
    * ``noise``
    * ``param_noise`` and ``reg_noise_std``
    * the target tensors ``itLEP_var``, ``ORI_var``, ``LEP_var`` and ``RF_var``
    * ``img_name``

    The loss is a weighted sum of MSE terms between the network output and the
    four targets, with weights 1.0, 0.1, 0.2 and 0.5. Gradients are
    accumulated with ``backward()``. Zeroing them is left to the caller
    (presumably ``optimize`` calls ``optimizer.zero_grad()`` -- TODO confirm).

    Returns:
        The scalar loss tensor.
    """
    if param_noise:
        # Perturb every 4-D parameter (conv kernels, presumably) in place.
        # The former `n = n + ...` only rebound the loop variable, so the
        # network was never modified and the option silently did nothing.
        with torch.no_grad():
            for n in [x for x in net.parameters() if len(x.size()) == 4]:
                n.add_(torch.randn_like(n) * n.std() / 50)

    net_input = net_input_saved
    if reg_noise_std > 0:
        # Jitter the fixed input with fresh Gaussian noise on every step.
        # `noise` is a reusable buffer that normal_() refills in place.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    total_loss = (mse(out, itLEP_var)
                  + mse(out, ORI_var) * 0.1
                  + mse(out, LEP_var) * 0.2
                  + mse(out, RF_var) * 0.5)
    total_loss.backward()

    # '\r' with end='' overwrites the same console line on each step.
    print('Image %s Loss %f' % (img_name, total_loss.item()), '\r', end='')
    return total_loss
# Reference face: loaded once and shared as an extra loss target (RF_var) by
# every image processed below.
RF_pil, RF_np = get_image(real_face_name, imsize)
RF_var = np_to_torch(RF_np).type(dtype)

# Output directory. Created up front so the final save cannot fail on a missing
# path.
out_dir = 'result/noise_input/0.01/'
os.makedirs(out_dir, exist_ok=True)

# NOTE: closure() reads the names bound in this loop as globals (img_name,
# itLEP_var, LEP_var, ORI_var, net, net_input_saved, noise), so the loop must
# stay at module level.
files = sorted(os.listdir(iteation_LEP))  # sorted: reproducible processing order
for img_name in files:
    # The same file name is looked up in all three input directories.
    itLEP_pil, itLEP_np = get_image(os.path.join(iteation_LEP, img_name), imsize)
    LEP_pil, LEP_np = get_image(os.path.join(LEP, img_name), imsize)
    ORI_pil, ORI_np = get_image(os.path.join(ORI, img_name), imsize)
    itLEP_var = np_to_torch(itLEP_np).type(dtype)
    LEP_var = np_to_torch(LEP_np).type(dtype)
    ORI_var = np_to_torch(ORI_np).type(dtype)

    # A freshly initialised network per image. Its output channel count is
    # itLEP_np.shape[0] (presumably C of a CHW array -- confirm get_image).
    net = skip(input_depth, itLEP_np.shape[0],
               num_channels_down=[128] * 5,
               num_channels_up=[128] * 5,
               num_channels_skip=[128] * 5,
               filter_size_up=3, filter_size_down=3,
               upsample_mode='nearest', filter_skip_size=1,
               need_sigmoid=True, need_bias=True,
               pad=pad, act_fun='LeakyReLU').type(dtype)

    # Network input with input_depth channels and the spatial size of the image.
    net_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
    # Disabled alternative: stack the four guide images as input channels
    # instead of noise (requires input_depth = 4).
    # net_input[0,0,:] = itLEP_var
    # net_input[0,1,:] = LEP_var
    # net_input[0,2,:] = ORI_var
    # net_input[0,3,:] = RF_var
    # net_input = np_to_torch(RF_np).type(dtype)

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()  # buffer closure() refills with normal_()

    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)

    # Final prediction on the clean (un-jittered) input. No autograd graph is
    # needed, so skip building one.
    with torch.no_grad():
        out_np = torch_to_np(net(net_input))
    img_save = np.clip(out_np, 0, 1)[0]  # first channel only, as before

    out_path = os.path.join(out_dir, img_name)
    if hasattr(scipy.misc, 'toimage'):
        scipy.misc.toimage(img_save, cmin=0.0, cmax=1.0).save(out_path)
    else:
        # scipy.misc.toimage was removed in SciPy 1.2. Fall back to matplotlib
        # with a fixed [0, 1] gray mapping. Note that this writes an RGBA PNG
        # rather than an 8-bit single-channel image.
        plt.imsave(out_path, img_save, cmap='gray', vmin=0.0, vmax=1.0)