Label-conditioned StarGAN (MultiPIE data)
import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")

from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random

# Model configuration.
c_dim = 5                  # dimension of domain labels (1st dataset)
c2_dim = 8                 # dimension of domain labels (2nd dataset)
celeba_crop_size = 178     # crop size for the CelebA dataset
rafd_crop_size = 256       # crop size for the RaFD dataset
image_size = 128           # image resolution
g_conv_dim = 64            # number of conv filters in the first layer of G
d_conv_dim = 64            # number of conv filters in the first layer of D
g_repeat_num = 6           # number of residual blocks in G
d_repeat_num = 6           # number of strided conv layers in D
lambda_cls = 1             # weight for the domain classification loss
lambda_rec = 10            # weight for the reconstruction loss
lambda_gp = 10             # weight for the gradient penalty

# Training configuration.
dataset = 'CelebA'         # choices: ['CelebA', 'RaFD', 'Both']
batch_size = 16            # mini-batch size
num_iters = 200000         # number of total iterations for training D
num_iters_decay = 100000   # number of iterations for decaying lr
g_lr = 0.0001              # learning rate for G
d_lr = 0.0001              # learning rate for D
n_critic = 5               # number of D updates per G update
beta1 = 0.5                # beta1 for the Adam optimizer
beta2 = 0.999              # beta2 for the Adam optimizer
resume_iters = None        # resume training from this step
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']  # selected attributes for the CelebA dataset

# Test configuration.
test_iters = 200000        # test the model from this step

# Miscellaneous.
num_workers = 1
mode = 'test'              # choices: ['train', 'test']
use_tensorboard = True

# Directories.
celeba_image_dir = '../data/CelebA_nocrop/images/' if mode == 'train' else '../test/test/'
attr_path = '../data/list_attr_celeba.txt' if mode == 'train' else '../test/test_celeba.txt'
rafd_image_dir = '../RaFD/train'
log_dir = '../test/logs'
model_save_dir = '../stargan/models'
sample_dir = '../test/samples'
result_dir = '../test/result'

# Step size.
log_step = 10
sample_step = 1000
model_save_step = 10000
lr_update_step = 1000
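The directories above are only paths; none of the later cells create them, so saving samples or checkpoints would fail if they do not exist. A minimal sketch (my addition, not part of the original script) that creates them up front:

# Assumption: it is safe to create these output directories if they are missing.
for d in (log_dir, model_save_dir, sample_dir, result_dir):
    os.makedirs(d, exist_ok=True)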
import tensorflow as tf  # used only for the TensorBoard summary writer


class Logger(object):
    """TensorBoard logger."""

    def __init__(self, log_dir):
        """Initialize the summary writer."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Add a scalar summary."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
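A quick usage sketch for the logger (hypothetical, not in the original). Note that tf.summary.FileWriter and tf.Summary are TensorFlow 1.x APIs; under TensorFlow 2.x they would have to be reached through tf.compat.v1.

# Hypothetical example: write a single scalar so it shows up in TensorBoard.
logger = Logger('../test/logs')                 # path taken from log_dir above
logger.scalar_summary('D/loss_real', 0.37, 1)   # tag, value, global step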
from torch.utils import data
from torchvision import transforms as T


class MulPIE(data.Dataset):
    """Dataset class for the MultiPIE dataset."""

    def __init__(self, image_dir, transform, mode):
        """Initialize and preprocess the MultiPIE dataset."""
        self.image_dir = image_dir
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Build (filename, source label, target label) triples from the folder structure."""
        dataset1 = ImageFolder(self.image_dir, self.transform)
        for names in dataset1.imgs:
            # The pose index is encoded in the last two digits of the filename
            # (characters -6:-4); turn it into a 20-dimensional one-hot vector.
            label = [0] * 20
            name = names[0]
            num = int(name[-6:-4])
            label[num] = 1
            # Fixed target label: translate every image to pose index 7.
            label_trg = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            self.train_dataset.append([name, label, label_trg])
            self.test_dataset.append([name, label, label_trg])

    # Inherited from torch.utils.data.Dataset.
    def __getitem__(self, index):
        """Return one image and its corresponding source and target labels."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label, label_trg = dataset[index]
        # ImageFolder already stores the full path, so the filename can be opened directly.
        image = Image.open(filename)
        return self.transform(image), torch.FloatTensor(label), torch.FloatTensor(label_trg)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def get_loader(image_dir, image_size=128, batch_size=16, mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())
        # transform.append(T.CenterCrop(178))  # cropping can be added here later
    transform.append(T.Resize(image_size))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    dataset = MulPIE(image_dir, transform, mode)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode == 'train'),
                                  num_workers=num_workers)
    return data_loader


image_dir = '../RaFD/'
mul_loader = get_loader(image_dir, image_size, batch_size, mode, num_workers)
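A short sanity check of the loader (hypothetical; it assumes the ../RaFD/ folder is laid out the way ImageFolder expects and that the source images are square, so that T.Resize(128) yields 128x128 tensors):

# Hypothetical check: inspect one batch from the loader.
imgs, c_org, c_trg = next(iter(mul_loader))
print(imgs.shape)    # expected: torch.Size([16, 3, 128, 128]) for square source images
print(c_org.shape)   # expected: torch.Size([16, 20]) -- one-hot source poses
print(c_trg.shape)   # expected: torch.Size([16, 20]) -- fixed target pose (index 7)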
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualBlock(nn.Module):
    """Residual block with instance normalization."""

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        return x + self.main(x)


class Generator(nn.Module):
    """Generator network."""

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()
        layers = []
        # First convolution: the input is the image concatenated with the label map,
        # so it has 3 image channels plus c_dim label channels.
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim  # 64 channels at this point
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2  # 256 channels after the two iterations

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        # Final convolution maps back to 3 image channels.
        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # Replicate the label vector spatially and concatenate it with the image.
        c = c.view(c.size(0), c.size(1), 1, 1)      # view works like NumPy's reshape
        c = c.repeat(1, 1, x.size(2), x.size(3))    # repeat the tensor along the spatial dimensions
        x = torch.cat([x, c], dim=1)                # concatenate the input image x with the label map c
        return self.main(x)


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)  # shared convolutional trunk
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)  # real/fake score map
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)         # domain label prediction

    def forward(self, x):
        h = self.main(x)         # shared features of the input image (a 2048-channel map for the default settings)
        out_src = self.conv1(h)  # patch-wise real/fake scores
        out_cls = self.conv2(h)  # predicted domain label
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
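A minimal shape check for the two networks (hypothetical; it instantiates them with the MultiPIE setting used below, c_dim=20, and pushes a random batch through):

# Hypothetical shape check with random data.
G = Generator(conv_dim=64, c_dim=20, repeat_num=6)
D = Discriminator(image_size=128, conv_dim=64, c_dim=20, repeat_num=6)

x = torch.randn(2, 3, 128, 128)        # two random "images"
c = torch.zeros(2, 20); c[:, 7] = 1    # target pose label (index 7)

print(G(x, c).shape)                   # torch.Size([2, 3, 128, 128])
out_src, out_cls = D(x)
print(out_src.shape, out_cls.shape)    # torch.Size([2, 1, 2, 2]) torch.Size([2, 20])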
from torchvision.utils import save_image
import time
import datetime


class Multi_Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, mul_loader):
        """Initialize configurations."""

        # Data loader.
        self.celeba_loader = mul_loader

        # Model configurations.
        self.c_dim = 20  # MultiPIE uses 20 pose labels, so the global c_dim=5 is overridden here
        self.c2_dim = c2_dim
        self.image_size = image_size
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.g_repeat_num = g_repeat_num
        self.d_repeat_num = d_repeat_num
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec
        self.lambda_gp = lambda_gp

        # Training configurations.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_iters = num_iters
        self.num_iters_decay = num_iters_decay
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.n_critic = n_critic
        self.beta1 = beta1
        self.beta2 = beta2
        self.resume_iters = resume_iters
        self.selected_attrs = selected_attrs

        # Test configurations.
        self.test_iters = test_iters

        # Miscellaneous.
        self.use_tensorboard = use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')  # uncomment to force CPU

        # Directories.
        self.log_dir = log_dir
        self.sample_dir = sample_dir
        self.model_save_dir = model_save_dir
        self.result_dir = result_dir

        # Step size.
        self.log_step = log_step
        self.sample_step = sample_step
        self.model_save_step = model_save_step
        self.lr_update_step = lr_update_step

        # Build the model and the TensorBoard logger.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)  # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        # Print the network architectures if needed.
        # self.print_network(self.G, 'G')
        # self.print_network(self.D, 'D')
        self.G.to(self.device)
        self.D.to(self.device)

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def build_tensorboard(self):
        """Build a TensorBoard logger."""
        self.logger = Logger(self.log_dir)

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Decay the learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss."""
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set the data loader.
        data_loader = self.celeba_loader
        data_iter = iter(data_loader)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Fetch a fixed batch of images for visualizing training progress.
        x_fixed, c_org, c_trg = next(data_iter)
        x_fixed = x_fixed.to(self.device)

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:                    # resume_iters is None by default
            start_iters = self.resume_iters      # training can resume from a previously saved step
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels; restart the iterator when the epoch ends.
            try:
                x_real, label_org, label_trg = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org, label_trg = next(data_iter)

            # The original StarGAN samples target labels by permuting the real labels:
            # rand_idx = torch.randperm(label_org.size(0))
            # label_trg = label_org[rand_idx]
            # Here the dataset already provides a fixed target label, so it is used directly.
            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)        # input images
            c_org = c_org.to(self.device)          # original domain labels
            c_trg = c_trg.to(self.device)          # target domain labels
            label_org = label_org.to(self.device)  # labels for computing the classification loss
            label_trg = label_trg.to(self.device)  # labels for computing the classification loss

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images. D receives only an image, not a label.
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)  # minimizing this pushes out_src up for real images
            d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)  # classification loss on the source label

            # Compute loss with fake images.
            x_fake = self.G(x_real, c_trg)              # generate a fake image in the target domain
            out_src, out_cls = self.D(x_fake.detach())  # detach so that no gradients flow back into G
            d_loss_fake = torch.mean(out_src)           # fake images should receive a low score

            # Compute loss for gradient penalty: interpolate between real and fake images with a
            # random factor alpha, feed the interpolation x_hat into D, and penalize gradients
            # whose L2 norm deviates from 1.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # The generator translates an image from the original domain to the target domain,
            # and maps the generated image back to the original domain (reconstruction).
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain: feed the real image and the target label into G.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)  # adversarial loss: the fake image should be scored as real
                g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)  # push towards the target label

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the fixed images for debugging every sample_step iterations.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    x_fake_list.append(self.G(x_fixed, c_trg))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints every model_save_step iterations.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
solver = Multi_Solver(mul_loader)
solver.train()
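This solver has no test() method. A minimal inference sketch (hypothetical, assuming a checkpoint saved at step test_iters exists in model_save_dir) that restores the trained generator and translates one batch of images to the fixed target pose:

# Hypothetical inference sketch: restore a checkpoint and translate one batch.
solver.restore_model(test_iters)
with torch.no_grad():
    x_real, c_org, c_trg = next(iter(mul_loader))
    x_real = x_real.to(solver.device)
    c_trg = c_trg.to(solver.device)
    x_fake = solver.G(x_real, c_trg)
    out_path = os.path.join(result_dir, 'translated.jpg')
    save_image(solver.denorm(x_fake.data.cpu()), out_path, nrow=4, padding=0)
    print('Saved translated images into {}...'.format(out_path))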