Sym-GAN

from __future__ import print_function
import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
import os
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
import cv2
import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \
    BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout
from mxnet import autograd
import numpy as np

epochs = 500
batch_size = 10

use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu()

lr = 0.0002
beta1 = 0.5
#lambda1 = 100
lambda1 = 10

pool_size = 50
img_horizon = mx.image.HorizontalFlipAug(1)  # augmenter that always flips horizontally
def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []

    """
    path='CAS/Lighting_aligned_128'
    ground_path = 'CAS/Lighting_aligned_128_retinex_to_color'
    
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
                      
            lingting_img = os.path.join(path, fname)
            ground_img = os.path.join(ground_path,fname)
                    
            #补充水平翻转和光照增加或者减少50%
            img_arr_fname = mx.image.imread(lingting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_gnema_t = img_horizon(img_arr_gnema)
            
            img_arr_fname = cv2.cvtColor(img_arr_fname.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_fname_t = cv2.cvtColor(img_arr_fname_t.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema = cv2.cvtColor(img_arr_gnema.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema_t = cv2.cvtColor(img_arr_gnema_t.asnumpy(), cv2.COLOR_RGB2LAB)
            
            
            img_arr_in, img_arr_out = [img_arr_fname[:,:,0].reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)
            
            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2,0,1)),
                                           nd.transpose(img_arr_gnema_t, (2,0,1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """       
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting/'
    mulpath_ground = 'MultiPIE/MultiPIE_Lighting/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            num = fname[14:16]
            if num != '07':
                lighting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpath_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
                img_arr_ground = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1

                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_ground = mx.image.imresize(img_arr_ground, 256, 256)
                # Augmentation: horizontal flip (lighting could also be raised/lowered by 50%)
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_ground_t = img_horizon(img_arr_ground)
                # 4 lighting images in total, 2 normal ground-truth images

                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2,0,1)),
                                           nd.transpose(img_arr_ground, (2,0,1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)

                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2,0,1)),
                                               nd.transpose(img_arr_ground_t, (2,0,1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)

    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
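
# Filename-pairing sketch (illustrative, with a made-up MultiPIE-style name):
# each lighting image is matched to its frontal-lighting counterpart by
# replacing the two-digit lighting index at characters 14-16 with '07'.
fname_example = '001_01_01_051_05.png'   # hypothetical file name
print(fname_example[:14] + '07.png')     # -> '001_01_01_051_07.png'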
    
img_wd = 256
img_ht = 256
train_img_path = '../data/edges2handbags/train_mini/'
val_img_path = '../data/edges2handbags/val/' 

def load_data(path, batch_size, is_reversed=False):
    img_in_list = []
    img_out_list = []
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2,0,1)),
                                       nd.transpose(img_arr_out, (2,0,1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_out if is_reversed else img_arr_in)
            img_out_list.append(img_arr_in if is_reversed else img_arr_out)

    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)


train_data = load_data(train_img_path, batch_size, is_reversed=False)
val_data = load_data(val_img_path, batch_size, is_reversed=False)
img_horizon = mx.image.HorizontalFlipAug(1)
def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []

    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Normal_aligned_128'
    """
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue

            temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R0_S0.png'
            ground_img = os.path.join(ground_path, temp_name)
            if not os.path.exists(ground_img):
                temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R1_S0.png'
                ground_img = os.path.join(ground_path, temp_name)
            if not os.path.exists(ground_img):
                continue
            lighting_img = os.path.join(path, fname)

            # Augmentation: horizontal flip (lighting could also be raised/lowered by 50%)
            img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)

            img_arr_ground = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_ground_t = img_horizon(img_arr_ground)

            img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2,0,1)),
                                       nd.transpose(img_arr_ground, (2,0,1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)

            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2,0,1)),
                                           nd.transpose(img_arr_ground_t, (2,0,1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting_128/'
    mulpath_ground = 'MultiPIE/MultiPIE_Lighting_128/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            num = fname[14:16]
            if num != '07':
                lighting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpath_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
                img_arr_ground = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1

                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_ground = mx.image.imresize(img_arr_ground, 256, 256)
                # Augmentation: horizontal flip (lighting could also be raised/lowered by 50%)
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_ground_t = img_horizon(img_arr_ground)
                # 4 lighting images in total, 2 normal ground-truth images

                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2,0,1)),
                                           nd.transpose(img_arr_ground, (2,0,1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)

                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2,0,1)),
                                               nd.transpose(img_arr_ground_t, (2,0,1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)

    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
    
def visualize(img_arr):
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')
def preview_train_data(train_data):
    img_in_list, img_out_list = train_data.next().data
    for i in range(4):
        plt.subplot(2,4,i+1)
        visualize(img_in_list[i])
        plt.subplot(2,4,i+5)
        visualize(img_out_list[i])
    plt.show()


train_data = load_retinex(10)  # switch train_data to the MultiPIE relighting pairs
preview_train_data(train_data)
# Define Unet generator skip block
class UnetSkipUnit(HybridBlock):
    def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False,
                 use_dropout=False, use_bias=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            self.outermost = outermost
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            if innermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels, use_bias=use_bias)
                encoder = [en_relu, en_conv]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + decoder
            elif outermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2)
                encoder = [en_conv]
                decoder = [de_relu, de_conv, Activation(activation='tanh')]
                model = encoder + [inner_block] + decoder
            else:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2, use_bias=use_bias)
                encoder = [en_relu, en_conv, en_norm]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            if use_dropout:
                model += [Dropout(rate=0.5)]

            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)

    def hybrid_forward(self, F, x):
        if self.outermost:
            return self.model(x)
        else:
            return F.concat(self.model(x), x, dim=1)

# Define Unet generator
class UnetGenerator(HybridBlock):
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
        super(UnetGenerator, self).__init__()

        #Build unet generator structure
        unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
        unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
        unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
        unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
        unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)

        with self.name_scope():
            self.model = unet

    def hybrid_forward(self, F, x):
        return self.model(x)
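
# Quick shape sanity check for the U-Net (a sketch, not part of the original
# training flow): a 3x128x128 input through a 6-level U-Net should come back
# with the same spatial size.
g_check = UnetGenerator(in_channels=3, num_downs=6)
g_check.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
x_check = nd.random_uniform(shape=(1, 3, 128, 128))
print(g_check(x_check).shape)   # expected: (1, 3, 128, 128)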

# Define the PatchGAN discriminator
class Discriminator(HybridBlock):
    def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
        super(Discriminator, self).__init__()

        with self.name_scope():
            self.model = HybridSequential()
            kernel_size = 4
            padding = int(np.ceil((kernel_size - 1)/2))
            self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2,
                                  padding=padding, in_channels=in_channels))
            self.model.add(LeakyReLU(alpha=0.2))

            nf_mult = 1
            for n in range(1, n_layers):
                nf_mult_prev = nf_mult
                nf_mult = min(2 ** n, 8)
                self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=2,
                                      padding=padding, in_channels=ndf * nf_mult_prev,
                                      use_bias=use_bias))
                self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                self.model.add(LeakyReLU(alpha=0.2))

            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n_layers, 8)
            self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult_prev,
                                  use_bias=use_bias))
            self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
            self.model.add(LeakyReLU(alpha=0.2))
            self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult))
            if use_sigmoid:
                self.model.add(Activation(activation='sigmoid'))

    def hybrid_forward(self, F, x):
        out = self.model(x)
        #print(out)
        return out
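
# PatchGAN sanity check sketch: for a 6-channel 128x128 input pair, the default
# 3-layer discriminator emits a grid of patch scores rather than a single scalar.
d_check = Discriminator(in_channels=6)
d_check.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
pair_check = nd.random_uniform(shape=(1, 6, 128, 128))
print(d_check(pair_check).shape)   # e.g. (1, 1, 19, 19) with these settings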
def param_init(param):
    if param.name.find('conv') != -1:
        if param.name.find('weight') != -1:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
            
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif param.name.find('batchnorm') != -1:
        param.initialize(init=mx.init.Zero(), ctx=ctx)
        # Initialize gamma from normal distribution with mean 1 and std 0.02
        if param.name.find('gamma') != -1:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))

def network_init(net):
    with net.name_scope():
        for param in net.collect_params().values():
            param_init(param)

def set_network():
    # Pixel2pixel networks
    netG1 = UnetGenerator(in_channels=3, num_downs=6)
    netD1 = Discriminator(in_channels=6)
    netG2 = UnetGenerator(in_channels=3, num_downs=6)
    netD2 = Discriminator(in_channels=6)

    # Initialize parameters
    network_init(netG1)
    network_init(netD1)
    network_init(netG2)
    network_init(netD2)

    # trainer for the generator and the discriminator
    trainerG1 = gluon.Trainer(netG1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    
    trainerG2 = gluon.Trainer(netG2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    return netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2

# Losses. Using L2Loss as the GAN criterion gives a least-squares GAN (LSGAN)
# objective; the sigmoid cross-entropy (vanilla GAN) variant is kept for reference.
#GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
GAN_loss = gluon.loss.L2Loss()
L1_loss = gluon.loss.L1Loss()
L2_loss = gluon.loss.L2Loss()

netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2 = set_network()
class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        ret_imgs = []
        for i in range(images.shape[0]):
            image = nd.expand_dims(images[i], axis=0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                ret_imgs.append(image)
            else:
                p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
                if p > 0.5:
                    random_id = int(nd.random_uniform(0, self.pool_size - 1, shape=(1,)).asscalar())
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    ret_imgs.append(tmp)
                else:
                    ret_imgs.append(image)
        ret_imgs = nd.concat(*ret_imgs, dim=0)
        return ret_imgs
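
# Minimal ImagePool usage sketch (assumed shapes): the pool echoes incoming
# fakes back until it is full, after which each query may randomly swap in
# older fakes, which helps stabilise discriminator training.
pool_demo = ImagePool(pool_size=4)
fakes_demo = nd.random_uniform(shape=(2, 6, 8, 8))
mixed = pool_demo.query(fakes_demo)
print(mixed.shape)   # (2, 6, 8, 8)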

# Retinex helper code used for preprocessing

def singleScaleRetinex(img, sigma):
    retinex = np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma))
    return retinex

def multiScaleRetinex(img, sigma_list):
    retinex = np.zeros_like(img)
    for sigma in sigma_list:
        retinex += singleScaleRetinex(img, sigma)
    retinex = retinex / len(sigma_list)
    return retinex

def colorRestoration(img, alpha, beta):
    img_sum = np.sum(img, axis=2, keepdims=True)
    color_restoration = beta * (np.log10(alpha * img) - np.log10(img_sum))
    return color_restoration

def simplestColorBalance(img, low_clip, high_clip):
    total = img.shape[0] * img.shape[1]
    for i in range(img.shape[2]):
        unique, counts = np.unique(img[:, :, i], return_counts=True)
        low_val, high_val = unique[0], unique[-1]  # defaults in case a clip is 0
        current = 0
        for u, c in zip(unique, counts):
            if float(current) / total < low_clip:
                low_val = u
            if float(current) / total < high_clip:
                high_val = u
            current += c
        img[:, :, i] = np.maximum(np.minimum(img[:, :, i], high_val), low_val)
    return img

def MSRCR(img, sigma_list, G, b, alpha, beta, low_clip, high_clip):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)    
    img_color = colorRestoration(img, alpha, beta)    
    img_msrcr = G * (img_retinex * img_color + b)
    for i in range(img_msrcr.shape[2]):
        img_msrcr[:, :, i] = (img_msrcr[:, :, i] - np.min(img_msrcr[:, :, i])) / \
                             (np.max(img_msrcr[:, :, i]) - np.min(img_msrcr[:, :, i])) * \
                             255
    
    img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255))
    img_msrcr = simplestColorBalance(img_msrcr, low_clip, high_clip)       
    return img_msrcr

def automatedMSRCR(img, sigma_list):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    for i in range(img_retinex.shape[2]):
        unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True)
        zero_count = 0  # fall back to the full range below if no zero bin exists
        for u, c in zip(unique, count):
            if u == 0:
                zero_count = c
                break
            
        low_val = unique[0] / 100.0
        high_val = unique[-1] / 100.0
        for u, c in zip(unique, count):
            if u < 0 and c < zero_count * 0.1:
                low_val = u / 100.0
            if u > 0 and c < zero_count * 0.1:
                high_val = u / 100.0
                break
        img_retinex[:, :, i] = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val)
        
        img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \
                               (np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) \
                               * 255
    img_retinex = np.uint8(img_retinex)
    return img_retinex

def MSRCP(img, sigma_list, low_clip, high_clip):
    img = np.float64(img) + 1.0
    intensity = np.sum(img, axis=2) / img.shape[2]    
    retinex = multiScaleRetinex(intensity, sigma_list)
    intensity = np.expand_dims(intensity, 2)
    retinex = np.expand_dims(retinex, 2)
    intensity1 = simplestColorBalance(retinex, low_clip, high_clip)
    intensity1 = (intensity1 - np.min(intensity1)) / \
                 (np.max(intensity1) - np.min(intensity1)) * \
                 255.0 + 1.0
    img_msrcp = np.zeros_like(img)
    for y in range(img_msrcp.shape[0]):
        for x in range(img_msrcp.shape[1]):
            B = np.max(img[y, x])
            A = np.minimum(256.0 / B, intensity1[y, x, 0] / intensity[y, x, 0])
            img_msrcp[y, x, 0] = A * img[y, x, 0]
            img_msrcp[y, x, 1] = A * img[y, x, 1]
            img_msrcp[y, x, 2] = A * img[y, x, 2]

    img_msrcp = np.uint8(img_msrcp - 1.0)
    return img_msrcp
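
# Hedged usage sketch for the Retinex helpers above; the sigma list and clip
# values are common MSR-style defaults, and 'face.png' is a placeholder path.
demo_bgr = cv2.imread('face.png')                      # placeholder file name
demo_amsrcr = automatedMSRCR(demo_bgr, [15, 80, 250])  # multi-scale sigmas
demo_msrcp = MSRCP(demo_bgr, [15, 80, 250], 0.01, 0.99)
cv2.imwrite('face_retinex.png', demo_amsrcr)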

# Pre-training

from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()
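
# Tiny sketch of what facc measures, on made-up patch scores:
demo_pred = np.array([0.9, 0.2, 0.7])
demo_label = np.array([1.0, 0.0, 0.0])
print(facc(demo_label, demo_pred))   # 2 of 3 thresholded predictions match -> ~0.667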
def pre_train():
    metric = mx.metric.CustomMetric(facc)
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch in train_data:
            ############################
            # Update both generators with a plain L1 reconstruction loss
            ###########################
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
           
               
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(fake_out, real_out)*lambda1
                #errG1 = land_mark_errs(real_in, fake_out)
                errG1.backward()

            trainerG1.step(batch.data[0].shape[0])
             
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(fake_out2, real_in)*lambda1 
                errG2.backward()

            trainerG2.step(batch.data[0].shape[0])
       
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG2).asscalar(), acc, iter, epoch))
           
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        
        #fake_img2 = fake_out2[0]
        #visualize(fake_img2)
        #plt.show()

pre_train()
def save_data(path, tpath):
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            #img_arr_in = mx.image.imresize(img_arr_in, 128, 128)
            #img_arr_out = mx.image.imresize(img_arr_out, 128, 128)
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2,0,1)),
                                       nd.transpose(img_arr_out, (2,0,1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_out = netG1(img_arr_out.as_in_context(ctx))
            img_out1 = img_out[0]
            img_out2 = ((img_out1.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)

            save_name = tpath + fname
            # mx.image decodes to RGB, while cv2.imwrite expects BGR
            cv2.imwrite(save_name, cv2.cvtColor(img_out2, cv2.COLOR_RGB2BGR))

save_data("../data/edges2handbags/val/","../data/edges2handbags/G1andG2/")
netD1 = Discriminator(in_channels=6)
netD2 = Discriminator(in_channels=6)
network_init(netD1)
network_init(netD2)
trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})

# Note: retinex_data and PIE_normal_to_lighting are assumed to be data
# iterators built earlier (e.g. with load_retinex / load_data); their
# construction is not shown in the post.
def dual_pre_train():
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            ############################
            # Update both generators with L1 reconstruction + cycle losses
            ###########################
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighting_bad = batch2.data[0].as_in_context(ctx)
            lighting_good = batch2.data[1].as_in_context(ctx)
                      
                     
            with autograd.record():
                fake_out = netG1(real_in)
                #errG1 = L1_loss(real_out, fake_out) + L1_loss(netG1(netG2(fake_out)), real_out)
                errG1 = L1_loss(real_in, fake_out) + L1_loss(netG1(netG2(lighting_good)), lighting_good)
                # A third cycle-style loss term could be added:
                #errG1 = L1_loss(real_out, fake_out) + L1_loss(netG1(netG2(fake_out)), real_out) \
                #        + L1_loss(netG1(netG2(fake_out)), fake_out)
                errG1.backward()

            trainerG1.step(batch1.data[0].shape[0])
            
            with autograd.record():
                fake_out3 = netG2(real_out)
                #errG2 = L1_loss(real_in, fake_out3) + L1_loss(netG2(netG1(fake_out3)), real_in)
                errG2 = L1_loss(lighting_good, fake_out3) + L1_loss(netG2(netG1(real_in)), real_in)
                # A third cycle-style loss term could be added:
                #errG2 = L1_loss(real_in, fake_out3) + L1_loss(netG2(netG1(fake_out3)), real_in) \
                #        + L1_loss(netG2(netG1(fake_out3)), fake_out3)
                errG2.backward()

            trainerG2.step(batch2.data[0].shape[0])
            
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

dual_pre_train()
def test_netG(Spath,Tpath):
    for path, _, fnames in os.walk(Spath):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            #num = fname[14:16]
            #if num !='07':
                #continue
            test_img = os.path.join(path, fname)
            img_fname = mx.image.imread(test_img) 
            img_arr_fname = img_fname.astype(np.float32)/127.5 - 1
            img_arr_fname = mx.image.imresize(img_arr_fname,128,128)
            img_arr_in = nd.transpose(img_arr_fname, (2,0,1))
            img_arr_in = img_arr_in.reshape((1,) + img_arr_in.shape)
            img_out = netG1(img_arr_in.as_in_context(ctx))
            img_out = img_out[0]
            #img_out = mx.image.imresize(img_out,120,165)
            save_name = Tpath + fname
            plt.imsave(save_name, ((img_out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8) )
            
#test_netG('MultiPIE/MultiPIE_test_128_Gray/','MultiPIE/relighting/')
test_netG('MultiPIE/Bio_relighing2/','MultiPIE/Bio_color/')

# Use facial landmark distances (dlib, via openface's AlignDlib) as a loss

fileDir = '/home/hxj/gluon-tutorials/GAN/openface/'
sys.path.append(os.path.join(fileDir))
import dlib
from openface.align_dlib import AlignDlib
modelDir = os.path.join(fileDir, 'models')
openfaceModelDir = os.path.join(modelDir, 'openface')
dlibModelDir = os.path.join(modelDir, 'dlib')
dlibFacePredictor= os.path.join(dlibModelDir, "shape_predictor_68_face_landmarks.dat")

def land_mark_errs(batch1, batch2):
    align = AlignDlib(dlibFacePredictor)
    sum_err = nd.zeros((10)).as_in_context(ctx)  # assumes batch_size == 10
    i = 0
    for (x, y) in zip(batch1, batch2):
        x1 = ((x.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        y1 = ((y.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        """
        bbx = align.getLargestFaceBoundingBox(x1)
        if bbx is None:
            x1_r = MSRCR(x1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bbx = align.getLargestFaceBoundingBox(x1_r)
            if bbx is None:
                lab = cv2.cvtColor(x1, cv2.COLOR_RGB2LUV)
                bbx = align.getLargestFaceBoundingBox(lab)

        bby = align.getLargestFaceBoundingBox(y1)
        if bby is None:
            y1_r = MSRCR(y1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bby = align.getLargestFaceBoundingBox(y1_r)
            if bby is None:
                lab = cv2.cvtColor(y1, cv2.COLOR_RGB2LUV)
                bby = align.getLargestFaceBoundingBox(lab)

        if bby is None:
            continue
        if bbx is None:
            continue
        """
        # Use a fixed bounding box, since detection often fails on badly lit faces
        bbx = dlib.rectangle(-19, -19, 124, 125)
        bby = dlib.rectangle(-19, -19, 124, 125)
        landmarks_x = align.findLandmarks(x1, bbx)
        landmarks_y = align.findLandmarks(y1, bby)
        if landmarks_x is None or landmarks_y is None:
            continue
        landmarks_x = nd.array(landmarks_x)
        landmarks_y = nd.array(landmarks_y)
        sum_err[i] = nd.sum(nd.abs(landmarks_x - landmarks_y)) / 68
        i += 1
    return sum_err
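
# Usage sketch: land_mark_errs breaks the autograd graph through asnumpy(), so
# it can only serve as a monitoring metric, not a differentiable loss.
# real_in below stands for any current (10, 3, 128, 128) batch in [-1, 1]:
#   fake_out = netG1(real_in)
#   errs = land_mark_errs(real_in, fake_out)   # per-image mean landmark L1 distance
#   print(errs)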

def generate_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
        #for batch in range(400):
            ############################
            # Update both generators with L1 reconstruction + cycle losses
            ###########################
            real_in = batch1.data[0].as_in_context(ctx)   # fetch the batch inputs
            real_out = batch1.data[1].as_in_context(ctx)  # and the batch targets

            # G1 (compute fake_out inside record so gradients reach netG1)
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(real_out, fake_out) * 20 + L1_loss(netG1(netG2(real_out)), real_out) * 10
                #       + land_mark_errs(real_in, fake_out) * 0.4
                errG1.backward()

            trainerG1.step(batch1.data[0].shape[0])

            # G2
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(real_in, fake_out2) * 20 + L1_loss(netG2(netG1(real_in)), real_in) * 10
                #       + land_mark_errs(real_out, fake_out2) * 0.4
                errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])
            
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
       
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        
        

generate_train_single()
bgrImg = cv2.imread('CAS/test_aligned_128/FM_000046_IFD+90_PM+00_EN_A0_D0_T0_BW_M0_R1_S0.png')
rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
plt.imshow(rgbImg)
plt.show()
lab = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2LAB)
plt.imshow(lab)
plt.show()

# Feed the L channel through netG1 as a single-channel image
# (this assumes a netG1 variant trained with in_channels=1)
img_test = lab[:, :, 0].astype(np.float32)/127.5 - 1
img_test = nd.array(img_test)
img_arr_in = img_test.reshape((1, 1,) + img_test.shape).as_in_context(ctx)
test1 = netG1(img_arr_in)
test2 = test1[0][0]
plt.imshow(((test2.asnumpy() + 1.0) * 127.5).astype(np.uint8), cmap='gray')
plt.show()

def Dual_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1  in train_data:
        #for batch in range(400):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch1.data[0].as_in_context(ctx)   # fetch the batch inputs
            real_out = batch1.data[1].as_in_context(ctx)  # and the batch targets
                      
           
            #D1
            fake_out = netG1(real_in)
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                output = netD1(fake_concat)
                #output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label,], [output,])

                # Train with the real pair (the history pool only holds fakes)
                real_concat = nd.concat(real_in, real_out, dim=1)
                # (note: the ground truth could also be passed through G1)
                output = netD1(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)

                errD1 = (errD_real + errD_fake) * 0.5
                errD1.backward()
                metric.update([real_label,], [output,])

            trainerD1.step(batch1.data[0].shape[0])

            #G1 (recompute fake_out inside record so gradients reach netG1)
            with autograd.record():
                fake_out = netG1(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD1(fake_concat)
                #output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + \
                #        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * 20 + \
                        L1_loss(netG1(netG2(real_out)), real_out) * 10
                #       + land_mark_errs(real_out, fake_out)
                errG1.backward()

            trainerG1.step(batch1.data[0].shape[0])
            
           
            #D2
            fake_out2 = netG2(real_out)
            fake_concat2 = image_pool.query(nd.concat(real_out, fake_out2, dim=1))
            with autograd.record():
                output2 = netD2(fake_concat2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2,], [output2,])

                # Train with the real pair
                real_concat2 = nd.concat(real_out, real_in, dim=1)
                output2 = netD2(real_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)

                errD2 = (errD_real2 + errD_fake2) * 0.5
                errD2.backward()
                metric.update([real_label2,], [output2,])

            trainerD2.step(batch1.data[0].shape[0])

            #G2 (recompute fake_out2 inside record so gradients reach netG2)
            with autograd.record():
                fake_out2 = netG2(real_out)
                fake_concat2 = nd.concat(real_out, fake_out2, dim=1)
                output2 = netD2(fake_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)

                #errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * lambda1 + \
                #        L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1
                errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * 20 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * 10
                #       + land_mark_errs(real_in, fake_out2)
                errG2.backward()

            trainerG2.step(batch1.data[0].shape[0])
            
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errD1).asscalar(),
                            nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(errD2).asscalar(),
                            nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
       
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        
        

Dual_train_single()

def train():
    # Note: this unconditional variant feeds 3-channel images straight to the
    # discriminators, so netD1/netD2 are assumed to be re-created with
    # in_channels=3 before calling it.
    #image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
        #for batch in range(400):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch1.data[0].as_in_context(ctx)   # fetch the batch inputs
            real_out = batch1.data[1].as_in_context(ctx)  # and the batch targets
            lighting_bad = batch2.data[0].as_in_context(ctx)
            lighting_good = batch2.data[1].as_in_context(ctx)
                      
            
            fake_out = netG1(real_in)
            #D1  
            with autograd.record():
                
                #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                #output = netD1(fake_concat)
                output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label,], [output,])
                           
                # Train with the real (well-lit) image
                #real_concat = image_pool.query(nd.concat(real_in, lighting_good, dim=1))
                output = netD1(lighting_good)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                
                errD1 = (errD_real + errD_fake) * 0.5 
                errD1.backward()
                metric.update([real_label,], [output,])

            trainerD1.step(batch1.data[0].shape[0])
           
            #G1
            with autograd.record():
                #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                #output = netD1(fake_concat)
                fake_out = netG1(real_in)
                output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + \
                #        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                        L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1
                errG1.backward()
                                                                                                     
            trainerG1.step(batch1.data[0].shape[0])
            
           
            #D2
            fake_out2 = netG2(lighting_good)
            with autograd.record():
                #fake_concat2 = image_pool.query(nd.concat(lighting_good, fake_out2, dim=1))
                output2 = netD2(fake_out2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2,], [output2,])
                           
                # Train with real image
                #real_concat2 = image_pool.query(nd.concat(lighing_good, real_in, dim=1))
                output2 = netD2(real_in)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                
                errD2 = (errD_real2 + errD_fake2) * 0.5 
                errD2.backward()
                metric.update([real_label2,], [output2,])

            trainerD2.step(batch2.data[0].shape[0])
           
            #G2
            with autograd.record():
                fake_out2 = netG2(lighting_good)
                #fake_concat2 = image_pool.query(nd.concat(lighting_good, fake_out2, dim=1))
                output2 = netD2(fake_out2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)

                #errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * lambda1 + \
                #        L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1
                errG2 = GAN_loss(output2, real_label2) + L1_loss(lighting_good, fake_out2) * lambda1 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                errG2.backward()
                
            trainerG2.step(batch2.data[0].shape[0])
            
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                         %(nd.mean(errD1).asscalar(),
                           nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                         %(nd.mean(errD2).asscalar(),
                           nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
       
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        
        

train()
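
# Optional checkpointing sketch (not in the original post): persist the
# generators after training so test_netG can be re-run without retraining.
# ParameterDict.save is the gluon API of this MXNet generation.
netG1.collect_params().save('netG1.params')
netG2.collect_params().save('netG2.params')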

 
