加标签的 StarGAN（MultiPIE 数据）

import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random

# ----------------------------------------------------------------------------
# Model configuration.
# ----------------------------------------------------------------------------
c_dim = 5                 # dimension of domain labels (1st dataset)
c2_dim = 8                # dimension of domain labels (2nd dataset)
celeba_crop_size = 178    # crop size for the CelebA dataset
rafd_crop_size = 256      # crop size for the RaFD dataset
image_size = 128          # image resolution
g_conv_dim = 64           # number of conv filters in the first layer of G
d_conv_dim = 64           # number of conv filters in the first layer of D
g_repeat_num = 6          # number of residual blocks in G
d_repeat_num = 6          # number of strided conv layers in D
lambda_cls = 1            # weight for domain classification loss
lambda_rec = 10           # weight for reconstruction loss
lambda_gp = 10            # weight for gradient penalty

# ----------------------------------------------------------------------------
# Training configuration.
# ----------------------------------------------------------------------------
dataset = 'CelebA'        # one of 'CelebA', 'RaFD', 'Both'
batch_size = 16           # mini-batch size
num_iters = 200000        # total number of training iterations for D
num_iters_decay = 100000  # number of final iterations over which lr decays
g_lr = 1e-4               # learning rate for G
d_lr = 1e-4               # learning rate for D
n_critic = 5              # number of D updates per G update
beta1 = 0.5               # beta1 for the Adam optimizer
beta2 = 0.999             # beta2 for the Adam optimizer
resume_iters = None       # resume training from this step (None = from scratch)
# Selected attributes for the CelebA dataset.
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']

# ----------------------------------------------------------------------------
# Test configuration.
# ----------------------------------------------------------------------------
test_iters = 200000       # test the model saved at this step

# ----------------------------------------------------------------------------
# Miscellaneous.
# ----------------------------------------------------------------------------
num_workers = 1
mode = 'test'             # one of 'train', 'test'
use_tensorboard = True

# ----------------------------------------------------------------------------
# Directories. The CelebA image/attribute locations depend on ``mode``.
# ----------------------------------------------------------------------------
if mode == 'train':
    celeba_image_dir = '../data/CelebA_nocrop/images/'
    attr_path = '../data/list_attr_celeba.txt'
else:
    celeba_image_dir = '../test/test/'
    attr_path = '../test/test_celeba.txt'
rafd_image_dir = '../RaFD/train'
log_dir = '../test/logs'
model_save_dir = '../stargan/models'
sample_dir = '../test/samples'
result_dir = '../test/result'

# ----------------------------------------------------------------------------
# Step sizes (in iterations).
# ----------------------------------------------------------------------------
log_step = 10
sample_step = 1000
model_save_step = 10000
lr_update_step = 1000
import tensorflow as tf
# TensorBoard logging helper.
class Logger(object):
    """Tensorboard logger that writes scalar summaries via the TF1 summary API."""

    def __init__(self, log_dir):
        """Open a ``tf.summary.FileWriter`` that writes event files into ``log_dir``."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Record one scalar ``value`` under ``tag`` at global ``step``."""
        entry = tf.Summary.Value(tag=tag, simple_value=value)
        self.writer.add_summary(tf.Summary(value=[entry]), step)
from torch.utils import data
from torchvision import transforms as T
class MulPIE(data.Dataset):
    """Dataset class for the MultiPIE dataset.

    Image paths are discovered with torchvision's ``ImageFolder``, so
    ``image_dir`` must contain one sub-directory per class. The source domain
    of each image is the integer formed by the last two characters of its file
    name stem (e.g. ``xxx_05.png`` -> 5) and is one-hot encoded into a vector
    of length ``num_domains``. Every sample is paired with the same one-hot
    target label selecting ``target_domain``.
    """

    def __init__(self, image_dir, transform, mode, num_domains=20, target_domain=7):
        """Initialize and preprocess the MultiPIE dataset.

        Args:
            image_dir: root directory handed to ``ImageFolder``.
            transform: callable applied to each loaded PIL image.
            mode: 'train' selects the training list; anything else the test list.
            num_domains: length of the one-hot label vectors (default 20).
            target_domain: index set to 1 in every target label (default 7).
        """
        self.image_dir = image_dir
        self.transform = transform
        self.mode = mode
        self.num_domains = num_domains
        self.target_domain = target_domain
        self.train_dataset = []
        self.test_dataset = []
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    @staticmethod
    def _domain_index(path):
        """Return the domain index encoded in the last two characters of the file stem."""
        stem = os.path.splitext(path)[0]
        return int(stem[-2:])

    @staticmethod
    def _one_hot(index, size):
        """Return a ``size``-long list of zeros with a 1 at position ``index``.

        Raises:
            ValueError: if ``index`` lies outside ``[0, size)``.
        """
        if not 0 <= index < size:
            raise ValueError('domain index {} is outside [0, {})'.format(index, size))
        vec = [0] * size
        vec[index] = 1
        return vec

    def preprocess(self):
        """Build the ``[path, source_label, target_label]`` sample lists."""
        folder = ImageFolder(self.image_dir, self.transform)
        label_trg = self._one_hot(self.target_domain, self.num_domains)
        for path, _ in folder.imgs:
            label = self._one_hot(self._domain_index(path), self.num_domains)
            # Both splits receive every sample (there is no train/test split here).
            self.train_dataset.append([path, label, label_trg])
            self.test_dataset.append([path, label, label_trg])

    # Overrides torch.utils.data.Dataset.__getitem__.
    def __getitem__(self, index):
        """Return ``(image_tensor, source_label, target_label)`` for ``index``."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label, label_trg = dataset[index]
        # ImageFolder already stores paths rooted at image_dir, so the path is
        # opened as-is; joining it with image_dir again duplicated the prefix
        # for relative roots. Converting to RGB guarantees 3 channels for the
        # networks and loads the pixels so the file handle can be closed.
        with Image.open(filename) as img:
            image = img.convert('RGB')
        return self.transform(image), torch.FloatTensor(label), torch.FloatTensor(label_trg)

    def __len__(self):
        """Return the number of images."""
        return self.num_images

def get_loader(image_dir, image_size=128, batch_size=16, mode='train', num_workers=1):
    """Build and return a DataLoader over the MultiPIE images in ``image_dir``.

    In 'train' mode, images are randomly flipped horizontally and the loader
    shuffles. In every mode, images are resized, converted to tensors and
    normalized to roughly [-1, 1].
    """
    is_train = mode == 'train'
    steps = [T.RandomHorizontalFlip()] if is_train else []
    # A crop (e.g. T.CenterCrop(178)) may be inserted here later.
    steps += [
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]

    dataset = MulPIE(image_dir, T.Compose(steps), mode)
    return data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=is_train,
                           num_workers=num_workers)


# Root directory scanned (via ImageFolder) for the MultiPIE images.
image_dir = '../RaFD/'
mul_loader = get_loader(image_dir, image_size=image_size, batch_size=batch_size,
                        mode=mode, num_workers=num_workers)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ResidualBlock(nn.Module):
    """Residual block with instance normalization.

    Computes ``x + (conv3x3 -> IN -> ReLU -> conv3x3 -> IN)(x)``. The skip
    addition needs compatible channel counts; callers here pass
    ``dim_in == dim_out``.
    """
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()

        def conv3x3(channels_in):
            # Same-size 3x3 convolution without bias (the norm supplies the shift).
            return nn.Conv2d(channels_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False)

        def inst_norm():
            return nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True)

        self.main = nn.Sequential(conv3x3(dim_in), inst_norm(), nn.ReLU(inplace=True),
                                  conv3x3(dim_out), inst_norm())

    def forward(self, x):
        residual = self.main(x)
        return x + residual


class Generator(nn.Module):
    """Generator network.

    Pipeline: 7x7 stem over the image concatenated with the label map, two
    stride-2 down-sampling convs (channels double each time), ``repeat_num``
    residual blocks, two stride-2 transposed convs (channels halve each time),
    then a 7x7 conv to 3 channels followed by tanh.
    """
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()

        def conv_norm_relu(conv, channels):
            # A (transposed) convolution followed by affine instance norm and ReLU.
            return [conv,
                    nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # Stem: input is the 3-channel image stacked with the c_dim label channels.
        blocks = conv_norm_relu(
            nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
            conv_dim)

        # Down-sampling: after two iterations width == 4 * conv_dim.
        width = conv_dim
        for _ in range(2):
            blocks += conv_norm_relu(
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
                width * 2)
            width *= 2

        # Bottleneck.
        blocks.extend(ResidualBlock(dim_in=width, dim_out=width) for _ in range(repeat_num))

        # Up-sampling back to conv_dim channels.
        for _ in range(2):
            blocks += conv_norm_relu(
                nn.ConvTranspose2d(width, width // 2, kernel_size=4, stride=2, padding=1, bias=False),
                width // 2)
            width //= 2

        # Output head: project to 3 channels and squash into [-1, 1].
        blocks += [nn.Conv2d(width, 3, kernel_size=7, stride=1, padding=3, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*blocks)

    def forward(self, x, c):
        """Translate images ``x`` towards the domains described by labels ``c``."""
        # Broadcast each label vector over the spatial grid, then stack with the image.
        label_map = c.view(c.size(0), c.size(1), 1, 1).expand(-1, -1, x.size(2), x.size(3))
        return self.main(torch.cat([x, label_map], dim=1))


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN.

    A stack of ``repeat_num`` stride-2 4x4 convs (LeakyReLU 0.01, channels
    doubling after the first) feeds two heads: ``conv1`` produces a patch map
    of real/fake scores, and ``conv2`` produces ``c_dim`` domain logits.
    """
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        stack = [nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.01)]
        width = conv_dim
        for _ in range(repeat_num - 1):
            stack += [nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                      nn.LeakyReLU(0.01)]
            width *= 2

        # The classification head's kernel spans the whole final feature map,
        # which has been halved repeat_num times.
        head_kernel = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*stack)
        # Real/fake score per patch.
        self.conv1 = nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Domain-label logits.
        self.conv2 = nn.Conv2d(width, c_dim, kernel_size=head_kernel, bias=False)

    def forward(self, x):
        """Return ``(patch_scores, class_logits)``; logits are reshaped to (N, c_dim)."""
        features = self.main(x)
        src_map = self.conv1(features)
        cls_logits = self.conv2(features)
        return src_map, cls_logits.view(cls_logits.size(0), cls_logits.size(1))
from torchvision.utils import save_image
import time
import datetime

class Multi_Solver(object):
    """Solver for training StarGAN on the MultiPIE loader.

    Hyper-parameters, directories and step sizes are read from the
    module-level configuration variables when the solver is constructed.
    """

    def __init__(self, mul_loader):
        """Initialize configurations and build the networks.

        Args:
            mul_loader: DataLoader yielding ``(image, source_label, target_label)``
                batches.
        """
        # Data loader.
        self.celeba_loader = mul_loader
        # Model configurations.
        self.c_dim = 20  # length of the one-hot MultiPIE domain labels
        self.c2_dim = c2_dim
        self.image_size = image_size
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.g_repeat_num = g_repeat_num
        self.d_repeat_num = d_repeat_num
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec
        self.lambda_gp = lambda_gp

        # Training configurations.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_iters = num_iters
        self.num_iters_decay = num_iters_decay
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.n_critic = n_critic
        self.beta1 = beta1
        self.beta2 = beta2
        self.resume_iters = resume_iters
        self.selected_attrs = selected_attrs

        # Test configurations.
        self.test_iters = test_iters

        # Miscellaneous.
        self.use_tensorboard = use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = log_dir
        self.sample_dir = sample_dir
        self.model_save_dir = model_save_dir
        self.result_dir = result_dir

        # Step size.
        self.log_step = log_step
        self.sample_step = sample_step
        self.model_save_step = model_save_step
        self.lr_update_step = lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, the discriminator and their Adam optimizers.

        Raises:
            ValueError: if ``self.dataset`` is not 'CelebA', 'RaFD' or 'Both'.
        """
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
        else:
            raise ValueError('Unsupported dataset: {!r}'.format(self.dataset))

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        self.G.to(self.device)
        self.D.to(self.device)

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1] (clamped, in place on the result)."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def build_tensorboard(self):
        """Create the TensorBoard logger writing into ``self.log_dir``."""
        self.logger = Logger(self.log_dir)

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator saved at ``resume_iters``."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the generator and discriminator optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute the domain classification loss.

        'CelebA': summed binary cross-entropy with logits, divided by batch size.
        'RaFD':   mean softmax cross-entropy (``target`` holds class indices).

        Raises:
            ValueError: for any other ``dataset`` (previously this silently
                returned ``None`` and failed later in the loss arithmetic).
        """
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)
        raise ValueError('No classification loss defined for dataset {!r}'.format(dataset))

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: mean over the batch of (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def train(self):
        """Train StarGAN within a single dataset.

        D is updated every iteration and G every ``n_critic`` iterations.
        Losses are printed every ``log_step`` iterations, sample translations
        are saved every ``sample_step``, checkpoints every ``model_save_step``,
        and learning rates decay linearly over the last ``num_iters_decay``
        iterations.
        """
        # Set data loader.
        data_loader = self.celeba_loader
        data_iter = iter(data_loader)
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Fixed inputs for visualizing progress. The target labels are kept
        # from the same batch so their batch size always matches x_fixed; the
        # batch fetched inside the loop can be a smaller final partial batch.
        x_fixed, _, c_fixed_trg = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_trg = c_fixed_trg.to(self.device)

        # Make sure the output directories exist before anything is written.
        os.makedirs(self.sample_dir, exist_ok=True)
        os.makedirs(self.model_save_dir, exist_ok=True)

        # Start training from scratch or resume from a saved step.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels; restart the loader at the end of an epoch.
            try:
                x_real, label_org, label_trg = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org, label_trg = next(data_iter)

            # Target domain labels come from the loader (not sampled randomly).
            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)           # Input images.
            c_org = c_org.to(self.device)             # Original domain labels.
            c_trg = c_trg.to(self.device)             # Target domain labels.
            label_org = label_org.to(self.device)     # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)     # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Loss on real images: D should score them high and classify their source domain.
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)

            # Loss on fake images; detach so no gradient flows into G here.
            x_fake = self.G(x_real, c_trg)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Gradient penalty on random interpolations between real and fake images.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # G must translate source-domain images to the target domain and
            # reconstruct the originals when translated back.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain: fool D and hit the target label.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)

                # Target-to-original domain (cycle reconstruction).
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the fixed images every sample_step iterations for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed, self.G(x_fixed, c_fixed_trg)]
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints every model_save_step iterations.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates linearly during the last num_iters_decay iterations.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
# Build the solver (networks, optimizers, optional TensorBoard logger) and train.
# NOTE(review): `mode` is configured as 'test', so mul_loader was built without
# flipping or shuffling even though training runs here — confirm intent.
solver = Multi_Solver(mul_loader=mul_loader)
solver.train()

 

posted @ 2018-08-14 10:23  白菜hxj  阅读(1436)  评论(0编辑  收藏  举报