dual_GAN实现 (使用tensorflow)

具体实现地址:https://github.com/codehxj/DualGAN

以下是改编成notebook版本

import sys; 
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
import os
from time import gmtime, strftime

#pp = pprint.PrettyPrinter()

#get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

    
def load_data(image_path, flip=False, is_test=False, image_size = 128):
    """Load one image file and return it rescaled for the network.

    The file is read with load_image, resized (and optionally flipped) by
    preprocess_img, then mapped linearly so pixel value 0 becomes -1 and
    255 becomes 1. The result always has at least three dimensions: a
    trailing channel axis is appended when the image has fewer.
    """
    raw = load_image(image_path)
    resized = preprocess_img(raw, img_size=image_size, flip=flip, is_test=is_test)
    scaled = resized/127.5 - 1.
    if scaled.ndim >= 3:
        return scaled
    # add a trailing channel axis so callers always get a 3-D array
    return np.expand_dims(scaled, axis=2)

def load_image(image_path):
    """Read the image at image_path using this module's imread defaults."""
    return imread(image_path)

def preprocess_img(img, img_size=128, flip=False, is_test=False):
    """Resize img to img_size x img_size and optionally mirror it.

    A left-right flip is applied with probability 0.5, but only when flip
    is true and is_test is false; no random number is drawn otherwise.
    """
    resized = scipy.misc.imresize(img, [img_size, img_size])
    if is_test or not flip:
        return resized
    if np.random.random() > 0.5:
        return np.fliplr(resized)
    return resized

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
    """Read an image file and pass it through transform.

    image_size is forwarded to transform as its npx argument.
    """
    pixels = imread(image_path, is_grayscale)
    return transform(pixels, image_size, is_crop, resize_w)

def save_images(images, size, image_path):
    """Tile a batch of network outputs into one grid and write it to disk.

    Args:
        images: batch of images in the network's [-1, 1] value range; they
            are mapped back to pixel scale with inverse_transform.
        size: grid layout passed on to imsave as [rows, cols].
        image_path: destination file. Its parent directory is created when
            it does not exist yet.

    Returns:
        Whatever imsave returns.
    """
    out_dir = os.path.dirname(image_path)
    # dirname() is '' for a bare filename; os.makedirs('') would raise, so
    # only create the directory when there actually is one to create.
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale = False):
    """Read an image file with scipy.misc.imread.

    When is_grayscale is true the read is done with flatten=True;
    otherwise the file is read with scipy's defaults.
    """
    if not is_grayscale:
        return scipy.misc.imread(path)
    return scipy.misc.imread(path, flatten = True)

def merge_images(images, size):
    """Map images back to pixel scale via inverse_transform.

    Despite the name no tiling happens here; size is accepted but unused.
    """
    restored = inverse_transform(images)
    return restored

def merge(images, size):
    """Tile a batch of equally sized images into a single grid image.

    images is indexed as (n, height, width[, channels]); a missing channel
    axis is added. size is (rows, cols); image k is placed at row
    k // cols and column k % cols.

    Returns a float array with the single channel repeated three times
    when the input has one channel, otherwise the grid cast to uint8.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    if images.ndim < 4:
        images = np.expand_dims(images, axis = 3)
    rows, cols = size[0], size[1]
    channels = images.shape[3]

    canvas = np.zeros((tile_h * rows, tile_w * cols, channels))
    for position, tile in enumerate(images):
        row, col = divmod(position, cols)
        top = row * tile_h
        left = col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    if channels == 1:
        # replicate the single channel into a 3-channel image (stays float)
        return np.concatenate([canvas, canvas, canvas], axis=2)
    return canvas.astype(np.uint8)

def imsave(images, size, path):
    """Tile images into a grid with merge and write the grid to path."""
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)

def transform(image, npx=64, is_crop=True, resize_w=64):
    """Optionally center-crop image, then rescale it so 0 -> -1 and 255 -> 1.

    npx is the crop width/height in pixels and is only used when is_crop
    is true.
    NOTE(review): center_crop is not defined in the visible source; it
    presumably comes from the star import of utils - confirm.
    """
    pixels = center_crop(image, npx, resize_w=resize_w) if is_crop else image
    return np.array(pixels)/127.5 - 1.

def inverse_transform(images):
    """Undo the [-1, 1] scaling: -1 maps to 0 and 1 maps to 255."""
    shifted = images + 1.
    return shifted * 127.5
import tensorflow as tf
from tensorflow.python.framework import ops
from utils import *
def batch_norm(x,  name="batch_norm"):
    """Normalize x with learned per-channel scale and offset.

    Mean and variance are taken over axes 1 and 2 with keep_dims=True, so
    the statistics are computed separately for every index along axis 0
    and axis 3 (presumably sample and channel of an NHWC tensor - TODO
    confirm) rather than across the batch.
    Creates variables "scale" and "center" inside variable scope name.
    """
    eps = 1e-6
    with tf.variable_scope(name):
        nchannels = x.get_shape()[3]
        gamma = tf.get_variable(
            "scale", [nchannels],
            initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        beta = tf.get_variable(
            "center", [nchannels],
            initializer=tf.constant_initializer(0.0, dtype = tf.float32))
        mean, variance = tf.nn.moments(x, axes=[1,2], keep_dims=True)
        inv_std = tf.rsqrt(variance + eps)
        return (x-mean)*inv_std * gamma + beta

def conv2d(input_, output_dim, 
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    """2-D convolution with SAME padding plus a bias term.

    Creates variables 'w' (truncated-normal init with the given stddev)
    and 'biases' (zero init) inside variable scope name. k_h/k_w are the
    kernel height/width and d_h/d_w the strides along the two middle axes.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('w', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        response = tf.nn.conv2d(input_, kernel, strides=[1, d_h, d_w, 1], padding='SAME')

        bias = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        biased = tf.nn.bias_add(response, bias)
        # reshape to the conv output's static shape
        return tf.reshape(biased, response.get_shape())

def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """Transposed 2-D convolution plus a bias term.

    Creates variables 'w' (normal init with the given stddev) and 'biases'
    (zero init) inside variable scope name. output_shape is the full shape
    of the result; its last entry is the number of output channels.

    Returns the output tensor, or the tuple (output, w, biases) when
    with_w is true.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        # filter layout: [height, width, output_channels, in_channels]
        kernel = tf.get_variable('w', [k_h, k_w, out_channels, input_.get_shape()[-1]],
                                 initializer=tf.random_normal_initializer(stddev=stddev))
        strides = [1, d_h, d_w, 1]
        try:
            upsampled = tf.nn.conv2d_transpose(input_, kernel, output_shape=output_shape,
                                               strides=strides)
        except AttributeError:
            # Support for versions of TensorFlow before 0.7.0
            upsampled = tf.nn.deconv2d(input_, kernel, output_shape=output_shape,
                                       strides=strides)

        bias = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        result = tf.reshape(tf.nn.bias_add(upsampled, bias), upsampled.get_shape())

        if not with_w:
            return result
        return result, kernel, bias

def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: element-wise maximum of x and leak*x. name is unused."""
    leaked = leak*x
    return tf.maximum(x, leaked)

def celoss(logits, labels):
    """Mean sigmoid cross-entropy between logits and labels."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
    return tf.reduce_mean(per_element)
       
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

#from ops import *
#from utils import *

class DualNet(object):
    def __init__(self, sess, image_size=256, batch_size=1,fcn_filter_dim = 64,  \
                 A_channels = 3, B_channels = 3, dataset_name='', \
                 checkpoint_dir=None, lambda_A = 20., lambda_B = 20., \
                 sample_dir=None, loss_metric = 'L1', flip = False):
        """Store the configuration and build the graph.

        fcn_filter_dim is used both as the generator's base filter count
        (self.fcn_filter_dim) and the discriminator's (self.df_dim).
        sample_dir is accepted but not stored. build_model() is invoked
        at the end of construction.
        """
        # session and bookkeeping
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir

        # input data description
        self.image_size = image_size
        self.batch_size = batch_size
        self.flip = flip
        self.A_channels = A_channels
        self.B_channels = B_channels
        self.is_grayscale_A = (A_channels == 1)
        self.is_grayscale_B = (B_channels == 1)

        # network width
        self.fcn_filter_dim = fcn_filter_dim
        self.df_dim = fcn_filter_dim

        # loss configuration
        self.loss_metric = loss_metric
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B

        # directory name for output and logs saving
        name_fields = (
            self.dataset_name,
            self.image_size,
            self.fcn_filter_dim,
            self.loss_metric,
            self.lambda_A,
            self.lambda_B,
        )
        self.dir_name = "%s-img_sz_%s-fltr_dim_%d-%s-lambda_AB_%s_%s" % name_fields
        self.build_model()

    def build_model(self):
    ###    define place holders
        self.real_A = tf.placeholder(tf.float32,[self.batch_size, self.image_size, self.image_size,
                                         self.A_channels ],name='real_A')
        self.real_B = tf.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size,
                                         self.B_channels ], name='real_B')
        
    ###  define graphs
        self.A2B = self.A_g_net(self.real_A, reuse = False)
        self.B2A = self.B_g_net(self.real_B, reuse = False)
        self.A2B2A = self.B_g_net(self.A2B, reuse = True)
        self.B2A2B = self.A_g_net(self.B2A, reuse = True)
        
        if self.loss_metric == 'L1':
            self.A_loss = tf.reduce_mean(tf.abs(self.A2B2A - self.real_A))
            self.B_loss = tf.reduce_mean(tf.abs(self.B2A2B - self.real_B))
        elif self.loss_metric == 'L2':
            self.A_loss = tf.reduce_mean(tf.square(self.A2B2A - self.real_A))
            self.B_loss = tf.reduce_mean(tf.square(self.B2A2B - self.real_B))
        
        self.Ad_logits_fake = self.A_d_net(self.A2B, reuse = False)
        self.Ad_logits_real = self.A_d_net(self.real_B, reuse = True)
        self.Ad_loss_real = celoss(self.Ad_logits_real, tf.ones_like(self.Ad_logits_real))
        self.Ad_loss_fake = celoss(self.Ad_logits_fake, tf.zeros_like(self.Ad_logits_fake))
        self.Ad_loss = self.Ad_loss_fake + self.Ad_loss_real
        self.Ag_loss = celoss(self.Ad_logits_fake, labels=tf.ones_like(self.Ad_logits_fake))+self.lambda_B * (self.B_loss )

        self.Bd_logits_fake = self.B_d_net(self.B2A, reuse = False)
        self.Bd_logits_real = self.B_d_net(self.real_A, reuse = True)
        self.Bd_loss_real = celoss(self.Bd_logits_real, tf.ones_like(self.Bd_logits_real))
        self.Bd_loss_fake = celoss(self.Bd_logits_fake, tf.zeros_like(self.Bd_logits_fake))
        self.Bd_loss = self.Bd_loss_fake + self.Bd_loss_real
        self.Bg_loss = celoss(self.Bd_logits_fake, tf.ones_like(self.Bd_logits_fake))+self.lambda_A * (self.A_loss)
       
        self.d_loss = self.Ad_loss + self.Bd_loss
        self.g_loss = self.Ag_loss + self.Bg_loss
        ## define trainable variables
        t_vars = tf.trainable_variables()
        self.A_d_vars = [var for var in t_vars if 'A_d_' in var.name]
        self.B_d_vars = [var for var in t_vars if 'B_d_' in var.name]
        self.A_g_vars = [var for var in t_vars if 'A_g_' in var.name]
        self.B_g_vars = [var for var in t_vars if 'B_g_' in var.name]
        self.d_vars = self.A_d_vars + self.B_d_vars 
        self.g_vars = self.A_g_vars + self.B_g_vars
        self.saver = tf.train.Saver()

    def clip_trainable_vars(self, var_list):
        for var in var_list:
            self.sess.run(var.assign(tf.clip_by_value(var, -self.c, self.c)))

    def load_random_samples(self):
        """Draw batch_size random validation images from each domain.

        Files are picked (with numpy's default replacement behaviour) from
        ./datasets/<dataset_name>/val/A and .../val/B, loaded without
        flipping, and returned as two float32 arrays reshaped to
        (batch_size, image_size, image_size, -1): first A, then B.
        """
        batch_shape = (self.batch_size,self.image_size, self.image_size,-1)
        patterns = ('./datasets/{}/val/A/*.jpg', './datasets/{}/val/B/*.jpg')

        loaded = []
        for pattern in patterns:
            candidates = glob(pattern.format(self.dataset_name))
            chosen = np.random.choice(candidates,self.batch_size)
            loaded.append([load_data(f, image_size =self.image_size, flip = False) for f in chosen])

        sample_A_imgs = np.reshape(np.array(loaded[0]).astype(np.float32), batch_shape)
        sample_B_imgs = np.reshape(np.array(loaded[1]).astype(np.float32), batch_shape)
        return sample_A_imgs, sample_B_imgs

    def sample_shotcut(self, sample_dir, epoch_idx, batch_idx):
        """Translate a random validation batch and save the results.

        Writes four image grids (A2B, A2B2A, B2A, B2A2B) under
        ./<sample_dir>/<dir_name>/ with file names built from epoch_idx and
        batch_idx, then prints both reconstruction losses.
        """
        sample_A_imgs,sample_B_imgs = self.load_random_samples()
        feed = {self.real_A: sample_A_imgs, self.real_B: sample_B_imgs}
        grid = [self.batch_size,1]

        # two separate session runs, one per translation direction
        Ag, A2B2A_imgs, A2B_imgs = self.sess.run(
            [self.A_loss, self.A2B2A, self.A2B], feed_dict=feed)
        Bg, B2A2B_imgs, B2A_imgs = self.sess.run(
            [self.B_loss, self.B2A2B, self.B2A], feed_dict=feed)

        a2b_path = './{}/{}/{:06d}_{:04d}_A2B.jpg'.format(sample_dir,self.dir_name , epoch_idx, batch_idx)
        a2b2a_path = './{}/{}/{:06d}_{:04d}_A2B2A.jpg'.format(sample_dir,self.dir_name, epoch_idx,  batch_idx)
        b2a_path = './{}/{}/{:06d}_{:04d}_B2A.jpg'.format(sample_dir,self.dir_name, epoch_idx, batch_idx)
        b2a2b_path = './{}/{}/{:06d}_{:04d}_B2A2B.jpg'.format(sample_dir,self.dir_name, epoch_idx, batch_idx)

        save_images(A2B_imgs, grid, a2b_path)
        save_images(A2B2A_imgs, grid, a2b2a_path)
        save_images(B2A_imgs, grid, b2a_path)
        save_images(B2A2B_imgs, grid, b2a2b_path)

        print("[Sample] A_loss: {:.8f}, B_loss: {:.8f}".format(Ag, Bg))

    def train(self, args):
        """Train Dual GAN.

        Reads from args: lr (learning rate for both RMSProp optimizers),
        epoch (number of epochs), sample_dir, save_freq and checkpoint_dir.
        Each epoch re-lists and shuffles the training files of both
        domains and runs min(#A, #B) // batch_size optimisation steps.
        Samples are written whenever step % 100 == 1; checkpoints are
        written every save_freq steps.
        """
        decay = 0.9
        self.d_optim = tf.train.RMSPropOptimizer(args.lr, decay=decay) \
                          .minimize(self.d_loss, var_list=self.d_vars)
                          
        self.g_optim = tf.train.RMSPropOptimizer(args.lr, decay=decay) \
                          .minimize(self.g_loss, var_list=self.g_vars)          
        tf.global_variables_initializer().run()

        self.writer = tf.summary.FileWriter("./logs/"+self.dir_name, self.sess.graph)

        step = 1
        start_time = time.time()

        # NOTE: restoring uses self.checkpoint_dir while saving below uses
        # args.checkpoint_dir; the two are expected to be the same path.
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" Load failed...ignored...")
            print(" start training...")

        for epoch_idx in xrange(args.epoch):
            data_A = glob('./datasets/{}/train/A/*.jpg'.format(self.dataset_name))
            data_B = glob('./datasets/{}/train/B/*.jpg'.format(self.dataset_name))
            np.random.shuffle(data_A)
            np.random.shuffle(data_B)
            epoch_size = min(len(data_A), len(data_B)) // (self.batch_size)
            print('[*] training data loaded successfully')
            print("#data_A: %d  #data_B:%d" %(len(data_A),len(data_B)))
            print('[*] run optimizor...')

            for batch_idx in xrange(0, epoch_size):
                imgA_batch = self.load_training_imgs(data_A, batch_idx)
                imgB_batch = self.load_training_imgs(data_B, batch_idx)
                
                print("Epoch: [%2d] [%4d/%4d]"%(epoch_idx, batch_idx, epoch_size))
                step = step + 1
                self.run_optim(imgA_batch, imgB_batch, step, start_time)

                if np.mod(step, 100) == 1:
                    self.sample_shotcut(args.sample_dir, epoch_idx, batch_idx)

                # The original test `step % save_freq == 2` can never hold
                # for save_freq of 1 or 2, so no checkpoint was ever saved.
                # Comparing against `2 % save_freq` is identical for
                # save_freq > 2 and also works for the small values.
                if args.save_freq > 0 and step % args.save_freq == 2 % args.save_freq:
                    self.save(args.checkpoint_dir, step)

    def load_training_imgs(self, files, idx):
        """Load the idx-th batch of files as one float32 array.

        Takes files[idx*batch_size : (idx+1)*batch_size], loads each with
        load_data (flipping controlled by self.flip) and reshapes the
        stack to (batch_size, image_size, image_size, -1).
        """
        first = idx*self.batch_size
        last = (idx+1)*self.batch_size
        stacked = np.array([
            load_data(path, image_size =self.image_size, flip = self.flip)
            for path in files[first:last]
        ])
        target_shape = (self.batch_size,self.image_size, self.image_size,-1)
        return np.reshape(stacked.astype(np.float32), target_shape)
        
    def run_optim(self,batch_A_imgs, batch_B_imgs,  counter, start_time):
        _, Adfake,Adreal,Bdfake,Bdreal, Ad, Bd = self.sess.run(
            [self.d_optim, self.Ad_loss_fake, self.Ad_loss_real, self.Bd_loss_fake, self.Bd_loss_real, self.Ad_loss, self.Bd_loss], 
            feed_dict = {self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})
        _, Ag, Bg, Aloss, Bloss = self.sess.run(
            [self.g_optim, self.Ag_loss, self.Bg_loss, self.A_loss, self.B_loss], 
            feed_dict={ self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})

        _, Ag, Bg, Aloss, Bloss = self.sess.run(
            [self.g_optim, self.Ag_loss, self.Bg_loss, self.A_loss, self.B_loss], 
            feed_dict={ self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})

        print("time: %4.4f, Ad: %.2f, Ag: %.2f, Bd: %.2f, Bg: %.2f,  U_diff: %.5f, V_diff: %.5f" \
                    % (time.time() - start_time, Ad,Ag,Bd,Bg, Aloss, Bloss))
        print("Ad_fake: %.2f, Ad_real: %.2f, Bd_fake: %.2f, Bg_real: %.2f" % (Adfake,Adreal,Bdfake,Bdreal))

    def A_d_net(self, imgs, y = None, reuse = False):
        """Discriminator whose variables use the 'A_d_' prefix. y is unused."""
        logits = self.discriminator(imgs, prefix = 'A_d_', reuse = reuse)
        return logits
    
    def B_d_net(self, imgs, y = None, reuse = False):
        """Discriminator whose variables use the 'B_d_' prefix. y is unused."""
        logits = self.discriminator(imgs, prefix = 'B_d_', reuse = reuse)
        return logits
        
    def discriminator(self, image,  y=None, prefix='A_d_', reuse=False):
        """Build the convolutional discriminator and return its logits.

        Five conv layers: three with the default stride 2, then two with
        stride 1, the last one producing a single output channel. Layers
        1-3 are followed by batch_norm, and layers 0-3 by a leaky ReLU.
        All variable scopes are named prefix + layer name; y is unused.
        With reuse=True the existing variables are shared.
        """
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            if not reuse:
                assert scope.reuse == False
            else:
                scope.reuse_variables()

            width = self.df_dim

            # stride-2 stages
            act0 = lrelu(conv2d(image, width, name=prefix+'h0_conv'))
            conv1 = conv2d(act0, width*2, name=prefix+'h1_conv')
            act1 = lrelu(batch_norm(conv1, name = prefix+'bn1'))
            conv2 = conv2d(act1, width*4, name=prefix+'h2_conv')
            act2 = lrelu(batch_norm(conv2, name = prefix+ 'bn2'))

            # stride-1 stages keep the spatial size
            conv3 = conv2d(act2, width*8, d_h=1, d_w=1, name=prefix+'h3_conv')
            act3 = lrelu(batch_norm(conv3, name = prefix+ 'bn3'))
            logits = conv2d(act3, 1, d_h=1, d_w=1, name =prefix+'h4')
            return logits
        
    def A_g_net(self, imgs, reuse=False):
        """Generator built by fcn with the 'A_g_' variable prefix."""
        generated = self.fcn(imgs, prefix='A_g_', reuse = reuse)
        return generated
        

    def B_g_net(self, imgs, reuse=False):
        """Generator built by fcn with the 'B_g_' variable prefix."""
        generated = self.fcn(imgs, prefix = 'B_g_', reuse = reuse)
        return generated
        
    def fcn(self, imgs, prefix=None, reuse = False):
        """Build the encoder-decoder generator and return its tanh output.

        Eight stride-2 conv layers (e1..e8) are followed by eight
        transposed-conv layers (d1..d8). Each decoder stage d1..d7 is
        concatenated along axis 3 with the encoder output of matching
        size (skip connections). Dropout with keep probability 0.5 is
        applied to d1..d3 unconditionally.

        prefix names every variable scope and selects the number of output
        channels: 'B_g_' -> self.A_channels, 'A_g_' -> self.B_channels.
        Any other prefix leaves self.d8 unassigned by this call.
        With reuse=True the existing variables are shared.

        The shape comments below assume image_size == 256.
        """
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            if reuse:
                scope.reuse_variables()
            else:
                assert scope.reuse == False
            
            # target spatial sizes for the decoder stages
            s = self.image_size
            s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)

            # ---- encoder ----
            # imgs is (256 x 256 x input_c_dim)
            e1 = conv2d(imgs, self.fcn_filter_dim, name=prefix+'e1_conv')
            # e1 is (128 x 128 x self.fcn_filter_dim)
            e2 = batch_norm(conv2d(lrelu(e1), self.fcn_filter_dim*2, name=prefix+'e2_conv'), name = prefix+'bn_e2')
            # e2 is (64 x 64 x self.fcn_filter_dim*2)
            e3 = batch_norm(conv2d(lrelu(e2), self.fcn_filter_dim*4, name=prefix+'e3_conv'), name = prefix+'bn_e3')
            # e3 is (32 x 32 x self.fcn_filter_dim*4)
            e4 = batch_norm(conv2d(lrelu(e3), self.fcn_filter_dim*8, name=prefix+'e4_conv'), name = prefix+'bn_e4')
            # e4 is (16 x 16 x self.fcn_filter_dim*8)
            e5 = batch_norm(conv2d(lrelu(e4), self.fcn_filter_dim*8, name=prefix+'e5_conv'), name = prefix+'bn_e5')
            # e5 is (8 x 8 x self.fcn_filter_dim*8)
            e6 = batch_norm(conv2d(lrelu(e5), self.fcn_filter_dim*8, name=prefix+'e6_conv'), name = prefix+'bn_e6')
            # e6 is (4 x 4 x self.fcn_filter_dim*8)
            e7 = batch_norm(conv2d(lrelu(e6), self.fcn_filter_dim*8, name=prefix+'e7_conv'), name = prefix+'bn_e7')
            # e7 is (2 x 2 x self.fcn_filter_dim*8)
            e8 = batch_norm(conv2d(lrelu(e7), self.fcn_filter_dim*8, name=prefix+'e8_conv'), name = prefix+'bn_e8')
            # e8 is (1 x 1 x self.fcn_filter_dim*8)

            # ---- decoder ----
            # NOTE: self.d1..self.d8 (and their _w/_b) are overwritten on
            # every call, so they hold the tensors of the latest build only.
            self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
                [self.batch_size, s128, s128, self.fcn_filter_dim*8], name=prefix+'d1', with_w=True)
            d1 = tf.nn.dropout(batch_norm(self.d1, name = prefix+'bn_d1'), 0.5)
            d1 = tf.concat([d1, e7],3)
            # d1 is (2 x 2 x self.fcn_filter_dim*8*2)

            self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
                [self.batch_size, s64, s64, self.fcn_filter_dim*8], name=prefix+'d2', with_w=True)
            d2 = tf.nn.dropout(batch_norm(self.d2, name = prefix+'bn_d2'), 0.5)

            d2 = tf.concat([d2, e6],3)
            # d2 is (4 x 4 x self.fcn_filter_dim*8*2)

            self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
                [self.batch_size, s32, s32, self.fcn_filter_dim*8], name=prefix+'d3', with_w=True)
            d3 = tf.nn.dropout(batch_norm(self.d3, name = prefix+'bn_d3'), 0.5)

            d3 = tf.concat([d3, e5],3)
            # d3 is (8 x 8 x self.fcn_filter_dim*8*2)

            # no dropout from d4 onwards
            self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
                [self.batch_size, s16, s16, self.fcn_filter_dim*8], name=prefix+'d4', with_w=True)
            d4 = batch_norm(self.d4, name = prefix+'bn_d4')

            d4 = tf.concat([d4, e4],3)
            # d4 is (16 x 16 x self.fcn_filter_dim*8*2)

            self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
                [self.batch_size, s8, s8, self.fcn_filter_dim*4], name=prefix+'d5', with_w=True)
            d5 = batch_norm(self.d5, name = prefix+'bn_d5')
            d5 = tf.concat([d5, e3],3)
            # d5 is (32 x 32 x self.fcn_filter_dim*4*2)

            self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
                [self.batch_size, s4, s4, self.fcn_filter_dim*2], name=prefix+'d6', with_w=True)
            d6 = batch_norm(self.d6, name = prefix+'bn_d6')
            d6 = tf.concat([d6, e2],3)
            # d6 is (64 x 64 x self.fcn_filter_dim*2*2)

            self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
                [self.batch_size, s2, s2, self.fcn_filter_dim], name=prefix+'d7', with_w=True)
            d7 = batch_norm(self.d7, name = prefix+'bn_d7')
            d7 = tf.concat([d7, e1],3)
            # d7 is (128 x 128 x self.fcn_filter_dim*1*2)

            # the B->A generator outputs A_channels, the A->B generator B_channels
            if prefix == 'B_g_':
                self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),[self.batch_size, s, s, self.A_channels], name=prefix+'d8', with_w=True)
            elif prefix == 'A_g_':
                self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),[self.batch_size, s, s, self.B_channels], name=prefix+'d8', with_w=True)
            # d8 is (256 x 256 x output_c_dim)
            return tf.nn.tanh(self.d8)
    
    def save(self, checkpoint_dir, step):
        """Write a checkpoint named DualNet.model-<step>.

        The checkpoint goes into <checkpoint_dir>/<self.dir_name>, which is
        created when missing.
        """
        target_dir = os.path.join(checkpoint_dir, self.dir_name)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        save_prefix = os.path.join(target_dir, "DualNet.model")
        self.saver.save(self.sess, save_prefix, global_step=step)

    def load(self, checkpoint_dir):
        """Restore the checkpoint recorded for this run, if there is one.

        Looks in <checkpoint_dir>/<self.dir_name>. Returns True after a
        restore and False when no checkpoint state is found.
        """
        print(" [*] Reading checkpoint...")

        run_dir = os.path.join(checkpoint_dir, self.dir_name)
        state = tf.train.get_checkpoint_state(run_dir)
        if not (state and state.model_checkpoint_path):
            return False

        # only the file name is taken from the recorded path; the directory
        # is always run_dir
        latest = os.path.basename(state.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(run_dir, latest))
        return True

    def test(self, args):
        """Test DualNet.

        Restores the checkpoint from self.checkpoint_dir and, when that
        succeeds, translates the validation images of both domains into
        ./<args.test_dir>/<dir_name>/. The run name is appended to
        evaluation.txt inside that directory.
        """
        tf.global_variables_initializer().run()
        if not self.load(self.checkpoint_dir):
            # previously a failed load returned silently
            print(" [!] Load failed, nothing to test")
            return

        print(" [*] Load SUCCESS")
        test_dir = './{}/{}'.format(args.test_dir, self.dir_name)
        if not os.path.exists(test_dir):
            os.makedirs(test_dir)

        # Join with a separator: plain concatenation put the log next to
        # the directory as "<dir_name>evaluation.txt".
        log_path = os.path.join(test_dir, 'evaluation.txt')
        # the context manager closes the log even if a domain test raises
        with open(log_path, 'a') as test_log:
            # newline keeps entries from repeated runs on separate lines
            test_log.write(self.dir_name + '\n')
            self.test_domain(args, test_log, type = 'A')
            self.test_domain(args, test_log, type = 'B')
        
    def test_domain(self, args, test_log, type = 'A'):
        test_files = glob('./datasets/{}/val/{}/*.jpg'.format(self.dataset_name,type))
        # load testing input
        print("Loading testing images ...")
        test_imgs = [load_data(f, is_test=True, image_size =self.image_size, flip = args.flip) for f in test_files]
        print("#images loaded: %d"%(len(test_imgs)))
        test_imgs = np.reshape(np.asarray(test_imgs).astype(np.float32),(len(test_files),self.image_size, self.image_size,-1))
        test_imgs = [test_imgs[i*self.batch_size:(i+1)*self.batch_size]
                         for i in xrange(0, len(test_imgs)//self.batch_size)]
        test_imgs = np.asarray(test_imgs)
        test_path = './{}/{}/'.format(args.test_dir, self.dir_name)
        # test input samples
        if type == 'A':
            for i in xrange(0, len(test_files)//self.batch_size):
                filename_o = test_files[i*self.batch_size].split('/')[-1].split('.')[0]
                print(filename_o)
                idx = i+1
                A_imgs = np.reshape(np.array(test_imgs[i]), (self.batch_size,self.image_size, self.image_size,-1))
                print("testing A image %d"%(idx))
                print(A_imgs.shape)
                A2B_imgs, A2B2A_imgs = self.sess.run(
                    [self.A2B, self.A2B2A],
                    feed_dict={self.real_A: A_imgs}
                    )
                save_images(A_imgs, [self.batch_size, 1], test_path+filename_o+'_realA.jpg')
                save_images(A2B_imgs, [self.batch_size, 1], test_path+filename_o+'_A2B.jpg')
                save_images(A2B2A_imgs, [self.batch_size, 1], test_path+filename_o+'_A2B2A.jpg')
        elif type=='B':
            for i in xrange(0, len(test_files)//self.batch_size):
                filename_o = test_files[i*self.batch_size].split('/')[-1].split('.')[0]
                idx = i+1
                B_imgs = np.reshape(np.array(test_imgs[i]), (self.batch_size,self.image_size, self.image_size,-1))
                print("testing B image %d"%(idx))
                B2A_imgs, B2A2B_imgs = self.sess.run(
                    [self.B2A, self.B2A2B],
                    feed_dict={self.real_B:B_imgs}
                    )
                save_images(B_imgs, [self.batch_size, 1],test_path+filename_o+'_realB.jpg')
                save_images(B2A_imgs, [self.batch_size, 1],test_path+filename_o+'_B2A.jpg')
                save_images(B2A2B_imgs, [self.batch_size, 1],test_path+filename_o+'_B2A2B.jpg')
import argparse
#from model import DualNet
import tensorflow as tf

def _str2bool(value):
    """Parse a command-line boolean such as 'true' / 'False' / '1' / 'no'.

    argparse's type=bool is a trap: bool('False') is True because every
    non-empty string is truthy. Raises argparse.ArgumentTypeError for
    anything that is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser(description='Argument parser')

# Arguments related to network architecture
#parser.add_argument('--network_type', dest='network_type', default='fcn_4', help='fcn_1,fcn_2,fcn_4,fcn_8, fcn_16, fcn_32, fcn_64, fcn_128')
parser.add_argument('--image_size', dest='image_size', type=int, default=256, help='size of input images (applicable to both A images and B images)')
parser.add_argument('--fcn_filter_dim', dest='fcn_filter_dim', type=int, default=64, help='# of fcn filters in first conv layer')
parser.add_argument('--A_channels', dest='A_channels', type=int, default=1, help='# of channels of image A')
parser.add_argument('--B_channels', dest='B_channels', type=int, default=1, help='# of channels of image B')

# Arguments related to run mode
parser.add_argument('--phase', dest='phase', default='train', help='train, test')

# Arguments related to training
parser.add_argument('--loss_metric', dest='loss_metric', default='L1', help='L1, or L2')
parser.add_argument('--niter', dest='niter', type=int, default=30, help='# of iter at starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.00005, help='initial learning rate for adam')#0.0002
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
# parsed with _str2bool so that "--flip False" really yields False
parser.add_argument('--flip', dest='flip', type=_str2bool, default=True, help='if flip the images for data argumentation')
parser.add_argument('--dataset_name', dest='dataset_name', default='sketch-photo', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=50, help='# of epoch')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
parser.add_argument('--lambda_A', dest='lambda_A', type=float, default=20.0, help='# weights of A recovery loss')
parser.add_argument('--lambda_B', dest='lambda_B', type=float, default=20.0, help='# weights of B recovery loss')

# Arguments related to monitoring and outputs
parser.add_argument('--save_freq', dest='save_freq', type=int, default=50, help='save the model every save_freq sgd iterations')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')

# Key point: inside a notebook an explicit empty list must be passed so that
# argparse ignores the kernel's own argv; on the command line use
# args = parser.parse_args() instead.
args=parser.parse_args([])


def main(_):
    """Create the output directories, build DualNet and run the chosen phase.

    Configuration is read from the module-level args. The single
    positional parameter is ignored. args.phase == 'train' trains the
    model; any other value runs the test phase.
    """
    for out_dir in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    # args.flip is parsed into a bool, so comparing it with the string
    # 'True' was always False and flipping could never be enabled.
    # Going through str() accepts both a real bool and the text 'True'.
    flip = (str(args.flip) == 'True')

    with tf.Session() as sess:
        model = DualNet(sess, image_size=args.image_size, batch_size=args.batch_size,\
                        dataset_name=args.dataset_name,A_channels = args.A_channels, \
                        B_channels = args.B_channels, flip  = flip,\
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir,\
                        fcn_filter_dim = args.fcn_filter_dim,\
                        loss_metric=args.loss_metric, lambda_B=args.lambda_B, \
                        lambda_A= args.lambda_A)

        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)

# Script entry point. NOTE(review): tf.app.run() presumably parses flags,
# calls main(argv) and then exits the interpreter - confirm how this
# behaves inside a notebook kernel.
if __name__ == '__main__':
    tf.app.run()

 

posted @ 2018-04-27 19:59  白菜hxj  阅读(1104)  评论(0编辑  收藏  举报