Sym-GAN
from __future__ import print_function  # must precede all other statements

import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")

import os
import tarfile

import matplotlib as mpl
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
import cv2
import numpy as np

import mxnet as mx
from mxnet import autograd, gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet.gluon.nn import (Dense, Activation, Conv2D, Conv2DTranspose,
                            BatchNorm, LeakyReLU, Flatten, HybridSequential,
                            HybridBlock, Dropout)

# Hyperparameters
epochs = 500
batch_size = 10
use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu()
lr = 0.0002
beta1 = 0.5
#lambda1 = 100
lambda1 = 10    # weight of the L1 reconstruction terms
pool_size = 50  # size of the history buffer used by ImagePool
img_horizon = mx.image.HorizontalFlipAug(1)

def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []
    """
    Disabled CAS-PEAL branch. Note: as originally written it is broken --
    img_arr_in and img_arr_out are referenced before assignment.

    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Lighting_aligned_128_retinex_to_color'
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            lighting_img = os.path.join(path, fname)
            ground_img = os.path.join(ground_path, fname)
            # Augment with horizontal flips and raise/lower lighting by 50%
            img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_gnema_t = img_horizon(img_arr_gnema)
            img_arr_fname = cv2.cvtColor(img_arr_fname.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_fname_t = cv2.cvtColor(img_arr_fname_t.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema = cv2.cvtColor(img_arr_gnema.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema_t = cv2.cvtColor(img_arr_gnema_t.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_in, img_arr_out = [img_arr_fname[:, :, 0].reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)
            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema_t, (2, 0, 1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """
    # Multi-PIE branch: every non-'07' illumination pairs with the '07'
    # (frontal-lit) image of the same subject as ground truth.
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting/'
    mulpath_ground = 'MultiPIE/MultiPIE_Lighting/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            num = fname[14:16]
            if num != '07':
                lighting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpath_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_gnema = mx.image.imresize(img_arr_gnema, 256, 256)
                # Augment with horizontal flips (and optionally +/-50% lighting)
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_gnema_t = img_horizon(img_arr_gnema)
                # 4 lighting images in total, 2 normal ground-truth images
                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema, (2, 0, 1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)
                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                               nd.transpose(img_arr_gnema_t, (2, 0, 1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                   nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
img_wd = 256
img_ht = 256
train_img_path = '../data/edges2handbags/train_mini/'
val_img_path = '../data/edges2handbags/val/'

def load_data(path, batch_size, is_reversed=False):
    img_in_list = []
    img_out_list = []
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images (the pairs are stored side by side)
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2, 0, 1)),
                                       nd.transpose(img_arr_out, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_out if is_reversed else img_arr_in)
            img_out_list.append(img_arr_in if is_reversed else img_arr_out)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                   nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)

train_data = load_data(train_img_path, batch_size, is_reversed=False)
val_data = load_data(val_img_path, batch_size, is_reversed=False)
img_horizon = mx.image.HorizontalFlipAug(1)

def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []
    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Normal_aligned_128'
    """
    Disabled CAS-PEAL branch.

    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R0_S0.png'
            ground_img = os.path.join(ground_path, temp_name)
            if not os.path.exists(ground_img):
                temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R1_S0.png'
                ground_img = os.path.join(ground_path, temp_name)
                if not os.path.exists(ground_img):
                    continue
            lighting_img = os.path.join(path, fname)
            # Augment with horizontal flips and raise/lower lighting by 50%
            img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_gnema_t = img_horizon(img_arr_gnema)
            img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                       nd.transpose(img_arr_gnema, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)
            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema_t, (2, 0, 1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """
    # Multi-PIE branch (128x128 version)
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting_128/'
    mulpath_ground = 'MultiPIE/MultiPIE_Lighting_128/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            num = fname[14:16]
            if num != '07':
                lighting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpath_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lighting_img).astype(np.float32)/127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_gnema = mx.image.imresize(img_arr_gnema, 256, 256)
                # Augment with horizontal flips (and optionally +/-50% lighting)
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_gnema_t = img_horizon(img_arr_gnema)
                # 4 lighting images in total, 2 normal ground-truth images
                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema, (2, 0, 1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)
                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                               nd.transpose(img_arr_gnema_t, (2, 0, 1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                   nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
def visualize(img_arr):
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')

def preview_train_data(train_data):
    img_in_list, img_out_list = train_data.next().data
    for i in range(4):
        plt.subplot(2, 4, i + 1)
        visualize(img_in_list[i])
        plt.subplot(2, 4, i + 5)
        visualize(img_out_list[i])
    plt.show()

train_data = load_retinex(10)
preview_train_data(train_data)
# Define U-Net generator skip block
class UnetSkipUnit(HybridBlock):
    def __init__(self, inner_channels, outer_channels, inner_block=None,
                 innermost=False, outermost=False, use_dropout=False, use_bias=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            self.outermost = outermost
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            if innermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels,
                                          use_bias=use_bias)
                encoder = [en_relu, en_conv]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + decoder
            elif outermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels * 2)
                encoder = [en_conv]
                decoder = [de_relu, de_conv, Activation(activation='tanh')]
                model = encoder + [inner_block] + decoder
            else:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels * 2,
                                          use_bias=use_bias)
                encoder = [en_relu, en_conv, en_norm]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            if use_dropout:
                model += [Dropout(rate=0.5)]

            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)

    def hybrid_forward(self, F, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate the block's output with its input
            return F.concat(self.model(x), x, dim=1)


# Define U-Net generator
class UnetGenerator(HybridBlock):
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
        super(UnetGenerator, self).__init__()

        # Build the U-Net structure from the innermost block outwards
        unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
        unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
        unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
        unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
        unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)

        with self.name_scope():
            self.model = unet

    def hybrid_forward(self, F, x):
        return self.model(x)


# Define the PatchGAN discriminator
class Discriminator(HybridBlock):
    def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
        super(Discriminator, self).__init__()

        with self.name_scope():
            self.model = HybridSequential()
            kernel_size = 4
            padding = int(np.ceil((kernel_size - 1) / 2))
            self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2,
                                  padding=padding, in_channels=in_channels))
            self.model.add(LeakyReLU(alpha=0.2))

            nf_mult = 1
            for n in range(1, n_layers):
                nf_mult_prev = nf_mult
                nf_mult = min(2 ** n, 8)
                self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size,
                                      strides=2, padding=padding,
                                      in_channels=ndf * nf_mult_prev, use_bias=use_bias))
                self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                self.model.add(LeakyReLU(alpha=0.2))

            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n_layers, 8)
            self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size,
                                  strides=1, padding=padding,
                                  in_channels=ndf * nf_mult_prev, use_bias=use_bias))
            self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
            self.model.add(LeakyReLU(alpha=0.2))
            self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult))
            if use_sigmoid:
                self.model.add(Activation(activation='sigmoid'))

    def hybrid_forward(self, F, x):
        out = self.model(x)
        #print(out)
        return out
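A quick shape check helps confirm the wiring before training. This is a minimal sketch, not part of the original notebook: it runs freshly built copies of both networks on a dummy 128x128 batch on CPU. The exact PatchGAN output size depends on n_layers and padding.

# Sanity-check sketch (dummy data; CPU is fine for a dry run).
_g = UnetGenerator(in_channels=3, num_downs=6)
_d = Discriminator(in_channels=6)
_g.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
_d.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
_x = nd.random.uniform(-1, 1, shape=(1, 3, 128, 128))
_y = _g(_x)
print(_y.shape)                             # (1, 3, 128, 128): same size as the input
print(_d(nd.concat(_x, _y, dim=1)).shape)   # one logit per patch, e.g. (1, 1, 19, 19)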
def param_init(param):
    if param.name.find('conv') != -1:
        if param.name.find('weight') != -1:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif param.name.find('batchnorm') != -1:
        param.initialize(init=mx.init.Zero(), ctx=ctx)
        # Initialize gamma from a normal distribution with mean 1 and std 0.02
        if param.name.find('gamma') != -1:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))

def network_init(net):
    with net.name_scope():
        for param in net.collect_params().values():
            param_init(param)

def set_network():
    # Two pix2pix-style generator/discriminator pairs:
    # G1 maps lighting -> normal, G2 maps normal -> lighting.
    netG1 = UnetGenerator(in_channels=3, num_downs=6)
    netD1 = Discriminator(in_channels=6)
    netG2 = UnetGenerator(in_channels=3, num_downs=6)
    netD2 = Discriminator(in_channels=6)

    # Initialize parameters
    network_init(netG1)
    network_init(netD1)
    network_init(netG2)
    network_init(netD2)

    # Trainers for the generators and the discriminators
    trainerG1 = gluon.Trainer(netG1.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerG2 = gluon.Trainer(netG2.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})

    return netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2

# Losses
#GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
GAN_loss = gluon.loss.L2Loss()
L1_loss = gluon.loss.L1Loss()
L2_loss = gluon.loss.L2Loss()

netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2 = set_network()
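Since GAN_loss is gluon.loss.L2Loss rather than a sigmoid cross-entropy, the adversarial objective here is least-squares (LSGAN-style): real patches are pulled toward 1 and fake patches toward 0. A tiny sketch of the semantics, with arbitrary values:

# L2Loss(pred, label) averages 0.5 * (pred - label)^2 per sample.
o = nd.array([[0.3, 0.9]])
print(GAN_loss(o, nd.ones_like(o)))   # distance of "real" logits from 1
print(GAN_loss(o, nd.zeros_like(o)))  # distance of "fake" logits from 0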
class ImagePool():
    """History buffer of previously generated images, used so the
    discriminator sees a mix of current and past generator outputs."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        ret_imgs = []
        for i in range(images.shape[0]):
            image = nd.expand_dims(images[i], axis=0)
            if self.num_imgs < self.pool_size:
                # Fill the pool first, returning images unchanged
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                ret_imgs.append(image)
            else:
                p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
                if p > 0.5:
                    # Return a random old image and stash the new one in its place
                    random_id = nd.random_uniform(0, self.pool_size - 1,
                                                  shape=(1,)).astype(np.uint8).asscalar()
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    ret_imgs.append(tmp)
                else:
                    ret_imgs.append(image)
        ret_imgs = nd.concat(*ret_imgs, dim=0)
        return ret_imgs
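A minimal usage sketch (the 4x3x8x8 shape is arbitrary): with pool_size 50, the first 50 queried images pass through unchanged while filling the buffer; after that, each image has a 50% chance of being swapped for a stored one.

pool = ImagePool(pool_size=50)
fake_batch = nd.random.uniform(-1, 1, shape=(4, 3, 8, 8))
out = pool.query(fake_batch)
print(out.shape)      # (4, 3, 8, 8) -- same shape, possibly different images
print(pool.num_imgs)  # 4: the pool is still filling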
# Code for the Retinex preprocessing used below
def singleScaleRetinex(img, sigma):
    retinex = np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma))
    return retinex

def multiScaleRetinex(img, sigma_list):
    retinex = np.zeros_like(img)
    for sigma in sigma_list:
        retinex += singleScaleRetinex(img, sigma)
    retinex = retinex / len(sigma_list)
    return retinex

def colorRestoration(img, alpha, beta):
    img_sum = np.sum(img, axis=2, keepdims=True)
    color_restoration = beta * (np.log10(alpha * img) - np.log10(img_sum))
    return color_restoration

def simplestColorBalance(img, low_clip, high_clip):
    total = img.shape[0] * img.shape[1]
    for i in range(img.shape[2]):
        unique, counts = np.unique(img[:, :, i], return_counts=True)
        current = 0
        for u, c in zip(unique, counts):
            if float(current) / total < low_clip:
                low_val = u
            if float(current) / total < high_clip:
                high_val = u
            current += c
        img[:, :, i] = np.maximum(np.minimum(img[:, :, i], high_val), low_val)
    return img

def MSRCR(img, sigma_list, G, b, alpha, beta, low_clip, high_clip):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    img_color = colorRestoration(img, alpha, beta)
    img_msrcr = G * (img_retinex * img_color + b)
    for i in range(img_msrcr.shape[2]):
        img_msrcr[:, :, i] = (img_msrcr[:, :, i] - np.min(img_msrcr[:, :, i])) / \
                             (np.max(img_msrcr[:, :, i]) - np.min(img_msrcr[:, :, i])) * 255
    img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255))
    img_msrcr = simplestColorBalance(img_msrcr, low_clip, high_clip)
    return img_msrcr

def automatedMSRCR(img, sigma_list):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    for i in range(img_retinex.shape[2]):
        unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True)
        zero_count = 0  # guard: the original left this unbound when no zero bin exists
        for u, c in zip(unique, count):
            if u == 0:
                zero_count = c
                break
        low_val = unique[0] / 100.0
        high_val = unique[-1] / 100.0
        for u, c in zip(unique, count):
            if u < 0 and c < zero_count * 0.1:
                low_val = u / 100.0
            if u > 0 and c < zero_count * 0.1:
                high_val = u / 100.0
                break
        img_retinex[:, :, i] = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val)
        img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \
                               (np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) * 255
    img_retinex = np.uint8(img_retinex)
    return img_retinex

def MSRCP(img, sigma_list, low_clip, high_clip):
    img = np.float64(img) + 1.0
    intensity = np.sum(img, axis=2) / img.shape[2]
    retinex = multiScaleRetinex(intensity, sigma_list)
    intensity = np.expand_dims(intensity, 2)
    retinex = np.expand_dims(retinex, 2)
    intensity1 = simplestColorBalance(retinex, low_clip, high_clip)
    intensity1 = (intensity1 - np.min(intensity1)) / \
                 (np.max(intensity1) - np.min(intensity1)) * 255.0 + 1.0
    img_msrcp = np.zeros_like(img)
    for y in range(img_msrcp.shape[0]):
        for x in range(img_msrcp.shape[1]):
            B = np.max(img[y, x])
            A = np.minimum(256.0 / B, intensity1[y, x, 0] / intensity[y, x, 0])
            img_msrcp[y, x, 0] = A * img[y, x, 0]
            img_msrcp[y, x, 1] = A * img[y, x, 1]
            img_msrcp[y, x, 2] = A * img[y, x, 2]
    img_msrcp = np.uint8(img_msrcp - 1.0)
    return img_msrcp
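A short usage sketch. The parameter values mirror the ones used later in land_mark_errs; the image path 'face.png' is a placeholder.

# Retinex enhancement sketch; 'face.png' is a hypothetical path.
bgr = cv2.imread('face.png')  # HxWx3, uint8, BGR
enhanced = MSRCR(bgr, sigma_list=[15, 80, 250], G=5.0, b=25.0,
                 alpha=125.0, beta=46.0, low_clip=0.01, high_clip=0.99)
plt.imshow(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()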
# Pre-training
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def pre_train():
    # Note: the metric is never updated here, so the logged acc stays nan.
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch in train_data:
            # Supervised warm-up: train G1 and G2 with plain L1 losses,
            # no discriminators involved yet.
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)

            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(fake_out, real_out) * lambda1
                #errG1 = land_mark_errs(real_in, fake_out)
            errG1.backward()
            trainerG1.step(batch.data[0].shape[0])

            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(fake_out2, real_in) * lambda1
            errG2.backward()
            trainerG2.step(batch.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        #fake_img2 = fake_out2[0]
        #visualize(fake_img2)
        #plt.show()

pre_train()
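The notebook never persists the pre-trained weights, so a crash means re-running pre_train(). A minimal checkpointing sketch, assuming MXNet 1.3+ (older builds use save_params/load_params instead); the 'checkpoints/' paths are hypothetical:

os.makedirs('checkpoints', exist_ok=True)
netG1.save_parameters('checkpoints/netG1_pretrain.params')
netG2.save_parameters('checkpoints/netG2_pretrain.params')

# Later, to resume without re-running pre_train():
netG1.load_parameters('checkpoints/netG1_pretrain.params', ctx=ctx)
netG2.load_parameters('checkpoints/netG2_pretrain.params', ctx=ctx)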
def save_data(path, tpath):
    # Run the ground-truth half of each pair through netG1 and save the result
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            #img_arr_in = mx.image.imresize(img_arr_in, 128, 128)
            #img_arr_out = mx.image.imresize(img_arr_out, 128, 128)
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2, 0, 1)),
                                       nd.transpose(img_arr_out, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_out = netG1(img_arr_out.as_in_context(ctx))
            img_out1 = img_out[0]
            img_out2 = ((img_out1.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
            save_name = tpath + fname
            # Note: cv2.imwrite expects BGR; img_out2 is RGB, so the saved
            # file has its red and blue channels swapped.
            cv2.imwrite(save_name, img_out2)

save_data("../data/edges2handbags/val/", "../data/edges2handbags/G1andG2/")
# Re-create and re-initialize the discriminators (e.g. after generator pre-training)
netD1 = Discriminator(in_channels=6)
netD2 = Discriminator(in_channels=6)
network_init(netD1)
network_init(netD2)
trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam',
                          {'learning_rate': lr, 'beta1': beta1})
trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam',
                          {'learning_rate': lr, 'beta1': beta1})
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def dual_pre_train():
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()  # the original only reset PIE_normal_to_lighting
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            # Dual supervised warm-up with cycle-consistency terms;
            # no discriminators involved yet.
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighting_bad = batch2.data[0].as_in_context(ctx)
            lighting_good = batch2.data[1].as_in_context(ctx)

            with autograd.record():
                fake_out = netG1(real_in)
                #errG1 = L1_loss(real_out, fake_out) + L1_loss(netG1(netG2(fake_out)), real_out)
                errG1 = L1_loss(real_in, fake_out) + \
                        L1_loss(netG1(netG2(lighting_good)), lighting_good)
                # Possible third loss term:
                #errG1 = errG1 + L1_loss(netG1(netG2(fake_out)), fake_out)
            errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            with autograd.record():
                fake_out3 = netG2(real_out)
                #errG2 = L1_loss(real_in, fake_out3) + L1_loss(netG2(netG1(fake_out3)), real_in)
                errG2 = L1_loss(lighting_good, fake_out3) + \
                        L1_loss(netG2(netG1(real_in)), real_in)
                # Possible third loss term:
                #errG2 = errG2 + L1_loss(netG2(netG1(fake_out3)), fake_out3)
            errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

dual_pre_train()
def test_netG(Spath, Tpath):
    for path, _, fnames in os.walk(Spath):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            #num = fname[14:16]
            #if num != '07':
            #    continue
            test_img = os.path.join(path, fname)
            img_fname = mx.image.imread(test_img)
            img_arr_fname = img_fname.astype(np.float32)/127.5 - 1
            img_arr_fname = mx.image.imresize(img_arr_fname, 128, 128)
            img_arr_in = nd.transpose(img_arr_fname, (2, 0, 1))
            img_arr_in = img_arr_in.reshape((1,) + img_arr_in.shape)
            img_out = netG1(img_arr_in.as_in_context(ctx))
            img_out = img_out[0]
            #img_out = mx.image.imresize(img_out, 120, 165)
            save_name = Tpath + fname
            plt.imsave(save_name,
                       ((img_out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))

#test_netG('MultiPIE/MultiPIE_test_128_Gray/', 'MultiPIE/relighting/')
test_netG('MultiPIE/Bio_relighing2/', 'MultiPIE/Bio_color/')
# Use facial landmark positions as a loss (via the OpenFace/dlib 68-point predictor)
fileDir = '/home/hxj/gluon-tutorials/GAN/openface/'
sys.path.append(os.path.join(fileDir))

import argparse
import cv2
import dlib
import matplotlib.pyplot as plt
from pylab import plot
from openface.align_dlib import AlignDlib

modelDir = os.path.join(fileDir, 'models')
openfaceModelDir = os.path.join(modelDir, 'openface')
dlibModelDir = os.path.join(modelDir, 'dlib')
dlibFacePredictor = os.path.join(dlibModelDir, "shape_predictor_68_face_landmarks.dat")

def land_mark_errs(batch1, batch2):
    # Per-image mean L1 distance between the 68 dlib landmarks of the two batches
    align = AlignDlib(dlibFacePredictor)
    sum_err = nd.zeros((10)).as_in_context(ctx)
    i = 0
    for (x, y) in zip(batch1, batch2):
        x1 = ((x.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        y1 = ((y.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        """
        Disabled: detect the largest face box, retrying on a Retinex-enhanced
        and then a LUV-converted copy when detection fails; skip the pair if
        either box is still missing.

        bbx = align.getLargestFaceBoundingBox(x1)
        if bbx is None:
            x1_r = MSRCR(x1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bbx = align.getLargestFaceBoundingBox(x1_r)
            if bbx is None:
                lab = cv2.cvtColor(x1, cv2.COLOR_RGB2LUV)
                bbx = align.getLargestFaceBoundingBox(lab)
        bby = align.getLargestFaceBoundingBox(y1)
        if bby is None:
            y1_r = MSRCR(y1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bby = align.getLargestFaceBoundingBox(y1_r)
            if bby is None:
                lab = cv2.cvtColor(y1, cv2.COLOR_RGB2LUV)
                bby = align.getLargestFaceBoundingBox(lab)
        if bby is None:
            continue
        if bbx is None:
            continue
        """
        # Fixed bounding box covering the whole 128x128 aligned face
        bbx = dlib.rectangle(-19, -19, 124, 125)
        bby = dlib.rectangle(-19, -19, 124, 125)
        # Check for failure before wrapping in nd.array (the original wrapped
        # first, so its None checks could never fire)
        landmarks_x = align.findLandmarks(x1, bbx)
        landmarks_y = align.findLandmarks(y1, bby)
        if landmarks_x is None or landmarks_y is None:
            continue
        landmarks_x = nd.array(landmarks_x)
        landmarks_y = nd.array(landmarks_y)
        sum_err[i] = nd.sum(nd.abs(landmarks_x - landmarks_y)) / 68
        i += 1
    return sum_err
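A dry-run sketch, assuming the OpenFace repo and the dlib landmark model exist at the paths configured above. The random batches are stand-ins for real NCHW images in [-1, 1]:

a = nd.random.uniform(-1, 1, shape=(10, 3, 128, 128)).as_in_context(ctx)
b = nd.random.uniform(-1, 1, shape=(10, 3, 128, 128)).as_in_context(ctx)
errs = land_mark_errs(a, b)
print(errs)  # one mean landmark L1 distance per image (zeros where skipped)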
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def generate_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
            # Pull the paired input and output out of the training batch
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)

            # G1: reconstruction plus cycle-consistency loss. The forward
            # pass must run inside autograd.record() for gradients to flow;
            # the original computed fake_out outside the record scope.
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(real_out, fake_out) * 20 + \
                        L1_loss(netG1(netG2(real_out)), real_out) * 10
                #errG1 = errG1 + land_mark_errs(real_in, fake_out) * 0.4
            errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # G2 (the original was missing errG2.backward())
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(real_in, fake_out2) * 20 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * 10
                #errG2 = errG2 + land_mark_errs(real_out, fake_out2) * 0.4
            errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

generate_train_single()
from skimage import io

bgrImg = cv2.imread('CAS/test_aligned_128/FM_000046_IFD+90_PM+00_EN_A0_D0_T0_BW_M0_R1_S0.png')
rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
plt.imshow(rgbImg)
plt.show()

lab = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2LAB)
plt.imshow(lab)
plt.show()

# Feed only the L channel through netG1. Note: this cell assumes a generator
# built with in_channels=1; the 3-channel netG1 above would reject this input.
img_test = lab[:, :, 0].astype(np.float32)/127.5 - 1
img_test = nd.array(img_test)
img_arr_in = img_test.reshape((1, 1,) + img_test.shape).as_in_context(ctx)
test1 = netG1(img_arr_in)
test2 = test1[0][0]
# The original called cv2.imshow without a window name; use pyplot instead
plt.imshow(((test2.asnumpy() + 1.0) * 127.5).astype(np.uint8), cmap='gray')
plt.show()
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def Dual_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
            ############################
            # (1) Update D networks: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch1.data[0].as_in_context(ctx)   # paired input
            real_out = batch1.data[1].as_in_context(ctx)  # paired ground truth

            # D1
            fake_out = netG1(real_in)
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                output = netD1(fake_concat)
                #output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])

                # Train with the real pair (the original also routed it
                # through the image pool, which is unusual)
                real_concat = image_pool.query(nd.concat(real_in, real_out, dim=1))
                output = netD1(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD1 = (errD_real + errD_fake) * 0.5
            errD1.backward()
            metric.update([real_label, ], [output, ])
            trainerD1.step(batch1.data[0].shape[0])

            # G1. The forward pass is re-run inside record so gradients can
            # flow; no image pool here, since the generator must be scored
            # on its current output.
            with autograd.record():
                fake_out = netG1(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD1(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * 20 + \
                        L1_loss(netG1(netG2(real_out)), real_out) * 10
                #errG1 = errG1 + land_mark_errs(real_out, fake_out)
            errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # D2
            fake_out2 = netG2(real_out)
            fake_concat2 = image_pool.query(nd.concat(real_out, fake_out2, dim=1))
            with autograd.record():
                output2 = netD2(fake_concat2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2, ], [output2, ])

                # Train with the real pair
                real_concat2 = image_pool.query(nd.concat(real_out, real_in, dim=1))
                output2 = netD2(real_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                errD2 = (errD_real2 + errD_fake2) * 0.5
            errD2.backward()
            metric.update([real_label2, ], [output2, ])
            trainerD2.step(batch1.data[0].shape[0])

            # G2
            with autograd.record():
                fake_out2 = netG2(real_out)
                fake_concat2 = nd.concat(real_out, fake_out2, dim=1)
                output2 = netD2(fake_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * 20 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * 10
                #errG2 = errG2 + land_mark_errs(real_in, fake_out2)
            errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD1).asscalar(), nd.mean(errG1).asscalar(),
                                acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD2).asscalar(), nd.mean(errG2).asscalar(),
                                acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

Dual_train_single()
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def train():
    # Unpaired variant: the discriminators judge single images, not pairs.
    # Note: they must therefore be built with in_channels=3, not the
    # in_channels=6 used for the conditional setup above.
    #image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            ############################
            # (1) Update D networks: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ############################
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighting_bad = batch2.data[0].as_in_context(ctx)
            lighting_good = batch2.data[1].as_in_context(ctx)

            # D1
            fake_out = netG1(real_in)
            with autograd.record():
                #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                #output = netD1(fake_concat)
                output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])

                # Train with real (well-lit) images
                #real_concat = image_pool.query(nd.concat(real_in, lighting_good, dim=1))
                output = netD1(lighting_good)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD1 = (errD_real + errD_fake) * 0.5
            errD1.backward()
            metric.update([real_label, ], [output, ])
            trainerD1.step(batch1.data[0].shape[0])

            # G1
            with autograd.record():
                fake_out = netG1(real_in)
                output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + \
                #        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                        L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1
            errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # D2
            fake_out2 = netG2(lighting_good)
            with autograd.record():
                #fake_concat2 = image_pool.query(nd.concat(lighting_good, fake_out2, dim=1))
                output2 = netD2(fake_out2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2, ], [output2, ])

                # Train with real (lighting) images
                #real_concat2 = image_pool.query(nd.concat(lighting_good, real_in, dim=1))
                output2 = netD2(real_in)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                errD2 = (errD_real2 + errD_fake2) * 0.5
            errD2.backward()
            metric.update([real_label2, ], [output2, ])
            trainerD2.step(batch2.data[0].shape[0])

            # G2
            with autograd.record():
                fake_out2 = netG2(lighting_good)
                output2 = netD2(fake_out2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                #errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * lambda1 + \
                #        L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1
                errG2 = GAN_loss(output2, real_label2) + \
                        L1_loss(lighting_good, fake_out2) * lambda1 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
            errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD1).asscalar(), nd.mean(errG1).asscalar(),
                                acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD2).asscalar(), nd.mean(errG2).asscalar(),
                                acc, iter, epoch))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

train()