pytorch jupyter下的CycleGAN代码

模型用的是苹果转橘子(apple2orange)的数据集。可能是由于模型太大且图片数量不足(1000张左右),有些图片的转换效果不是很好。

模型是挂在天池上面跑的。还需要导入utils.py文件,我放在文末了。

import glob
import random
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import utils
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision.utils as vutils
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import itertools
import torchvision

定义一些超参

# ---- GPU selection ----
gpu_id = [0]
# Sets CUDA_VISIBLE_DEVICES to the ids listed above.
utils.cuda_devices(gpu_id)
# Pick the device to run on: first visible GPU if CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---- Hyper-parameters ----
epochs = 2500      # upper bound of the epoch loop
batch_size = 50    # images per batch handed out by the DataLoader
size = 64          # side length of the square training crops
lr = 0.0002        # learning rate shared by all Adam optimizers
# NOTE(review): n_critic and z_dim are not referenced anywhere else in the
# visible code -- possibly leftovers from another GAN script; confirm before
# removing them.
n_critic = 5
z_dim = 100

导入数据集

class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Every item is a dict ``{'A': image_a, 'B': image_b}`` with one image from
    each domain, both passed through the same transform pipeline.

    Args:
        root: Dataset folder that contains the ``<mode>A`` / ``<mode>B``
            sub-folders.
        transforms_: List of transforms that is composed in order and applied
            to every image. With ``None`` (the default) the RGB PIL image is
            returned untransformed.
        unaligned: When true, the B image is picked at random instead of
            being paired with the A image by index.
        mode: Prefix of the sub-folder names, e.g. ``'train'`` or ``'test'``.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        # A pipeline composed from `None` cannot be applied to an image, so
        # only build one when transforms were actually supplied.
        self.transform = transforms.Compose(transforms_) if transforms_ is not None else None
        self.unaligned = unaligned

        # Collect every file under `<root>/<mode>A` and `<root>/<mode>B`.
        # The lists are sorted (not shuffled) so index -> file is reproducible.
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

    def _load(self, path):
        """Open `path`, force it to 3-channel RGB and apply the transforms."""
        # Grayscale or RGBA files would otherwise reach a 3-channel
        # `Normalize` with the wrong number of channels. The `with` block
        # closes the underlying file once the pixel data has been converted.
        with Image.open(path) as img:
            img = img.convert('RGB')
        return img if self.transform is None else self.transform(img)

    def __getitem__(self, index):
        # Indices wrap around, so the shorter domain is reused when the two
        # folders hold a different number of files.
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unaligned mode: pair A with a randomly chosen B image.
            item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            item_B = self._load(self.files_B[index % len(self.files_B)])

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        # The dataset is as long as the larger of the two domains.
        return max(len(self.files_A), len(self.files_B))
# Dataset loader
# Augmentation pipeline: upscale to int(size*1.12) with bicubic interpolation,
# then take a random size x size crop so each epoch sees shifted views.
transforms_ = [transforms.Resize(int(size*1.12), Image.BICUBIC), 
               transforms.RandomCrop(size), 
               transforms.RandomHorizontalFlip(), # random horizontal flip
               transforms.ToTensor(),             # PIL.Image/np.ndarray (HWC) [0, 255] -> torch.FloatTensor (CHW) [0.0, 1.0]
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] # per-channel (x - 0.5) / 0.5 on all three channels
# unaligned=True: each A image is paired with a randomly drawn B image.
dataloader = torch.utils.data.DataLoader(ImageDataset(r'dataset/apple2orange', transforms_=transforms_, unaligned=True), 
                        batch_size=batch_size, shuffle=True)
# Show some training images: the 'B' half of one batch, at most 64 of them, as a grid.
real_batch = next(iter(dataloader))['B']
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns a CHW tensor; transpose to HWC, the layout imshow expects.
plt.imshow(np.transpose(vutils.make_grid(real_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

定义模型

class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    Input and output both have `in_features` channels, and the padding of 1
    offsets the 3x3 kernel, so the output has the same shape as the input.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv():
            # Reflection padding -> 3x3 convolution -> instance normalisation.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # Only the first half is followed by an activation; the in-place ReLU
        # saves memory.
        first_half = padded_conv()
        second_half = padded_conv()
        layers = first_half + [nn.ReLU(inplace=True)] + second_half

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 downsampling
    convolutions, `n_residual_blocks` residual blocks, two stride-2 transposed
    convolutions and a 7x7 Tanh output layer.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the generated image.
        n_residual_blocks: How many `ResidualBlock`s sit between the
            downsampling and upsampling stages.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=2):
        super().__init__()

        layers = []

        # Initial convolution block: reflection padding of 3 offsets the 7x7 kernel.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ])

        # Downsampling: channels double (64 -> 128 -> 256) at each stride-2 step.
        channels = 64
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        # Residual blocks at the widest channel count.
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: channels halve (256 -> 128 -> 64) at each stride-2 step.
        for _ in range(2):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # Output layer: Tanh keeps the generated pixels in [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class Discriminator(nn.Module):
    """Fully convolutional discriminator returning one score per image.

    Five 4x4 convolutions produce a single-channel score map, which `forward`
    averages over its spatial extent.

    Args:
        input_nc: Number of channels of the input image.
    """

    def __init__(self, input_nc):
        super().__init__()

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # First convolution has no normalisation layer.
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1), leaky()]

        # (in_channels, out_channels, stride) of the normalised convolutions.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out),
                leaky(),
            ]

        # FCN classification layer: one output channel of scores.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        # Global average pooling over the whole score map, then flatten to
        # one row per image.
        pooled = F.avg_pool2d(scores, scores.shape[2:])
        return pooled.view(scores.shape[0], -1)

实例化模型

# Two generators (A -> B and B -> A) and one discriminator per domain,
# all working on 3-channel images.
netG_A2B = Generator(3, 3)
netG_B2A = Generator(3, 3)
netD_A = Discriminator(3)
netD_B = Discriminator(3)
# Losses: MSE for the adversarial terms, L1 for the cycle and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
# utils.cuda calls .cuda() on every object when CUDA is available. The returned
# list is discarded, so this line relies on .cuda() moving modules in place.
utils.cuda([netG_A2B, netG_B2A, netD_A, netD_B, criterion_GAN, criterion_cycle, criterion_identity])
# Optimizers (no LR scheduler is created in this cell)
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),  # `itertools.chain` joins both generators' parameters into one iterable
                               lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

每次训练中断后,都要重新运行下面这段代码,加载最新保存的检查点,作为本次运行的起点。

# Resume from the newest checkpoint in `ckpt_dir`, or start at epoch 0 when
# there is none.
ckpt_dir = './checkpoints1/celeba_cyclegan'
utils.mkdir(ckpt_dir)
try:
    ckpt = utils.load_checkpoint(ckpt_dir)
except FileNotFoundError:
    # Only a missing checkpoint means "start from scratch". Any other failure
    # (corrupt file, missing key, architecture mismatch) must surface instead
    # of silently restarting training with partially loaded weights.
    print(' [*] No checkpoint!')
    start_epoch = 0
else:
    start_epoch = ckpt['epoch']
    netD_A.load_state_dict(ckpt['netD_A'])
    netD_B.load_state_dict(ckpt['netD_B'])
    netG_A2B.load_state_dict(ckpt['netG_A2B'])
    netG_B2A.load_state_dict(ckpt['netG_B2A'])
    optimizer_G.load_state_dict(ckpt['optimizer_G'])
    optimizer_D_A.load_state_dict(ckpt['optimizer_D_A'])
    optimizer_D_B.load_state_dict(ckpt['optimizer_D_B'])
class ReplayBuffer():
    """Bounded history of generated images.

    `push_and_pop` takes a batch and returns a batch of the same size in
    which some images may have been swapped for ones stored earlier.

    Args:
        max_size: Maximum number of single images kept in the buffer.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store the images of `data` and return a batch of the same size.

        While the buffer is not full every image is stored and returned
        unchanged. Once it is full, each image is either (with probability
        0.5) swapped with a random stored image, which is returned instead,
        or returned directly without being stored.

        Args:
            data: Tensor whose first dimension indexes the images of a batch.

        Returns:
            A tensor detached from the autograd graph, with one entry per
            image in `data`.
        """
        # `detach()` replaces the deprecated `.data` attribute; the returned
        # batch carries no gradient history either way.
        batch = data.detach()
        to_return = []
        for element in batch:
            element = element.unsqueeze(0)  # restore the leading batch dimension
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())  # copy before the slot is overwritten
                self.data[i] = element
            else:
                to_return.append(element)
        if not to_return:
            # `torch.cat` rejects an empty list; hand an empty batch straight back.
            return batch
        return torch.cat(to_return)

定义一些用到的变量

# Inputs & targets memory allocation
# Everything is allocated on `device`, so this cell also runs on a machine
# without CUDA instead of failing on `torch.cuda.FloatTensor`.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
input_A = torch.empty(batch_size, 3, size, size, device=device)
input_B = torch.empty(batch_size, 3, size, size, device=device)
# Constant discriminator targets: 1.0 for "real", 0.0 for "fake".
# Freshly created tensors do not require grad, so no `Variable` wrapper is needed.
target_real = torch.ones(batch_size, device=device)
target_fake = torch.zeros(batch_size, device=device)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
A = []  # image grids of generated A samples, kept for display
B = []  # image grids of generated B samples, kept for display
# Main training loop: per batch, one generator update followed by one update
# of each discriminator; per epoch, optional sample images and a checkpoint.
for epoch in range(start_epoch, epochs):
    for i, batch in enumerate(dataloader):
        # Set model input (N, 3, H, W). Batches go straight to the training
        # device without a fixed-size staging buffer, so the last batch of an
        # epoch is trained on too, whatever its size.
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        #-------- Generators A2B and B2A --------
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*5.0
        # G_B2A(A) should equal A if real A is fed
        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*5.0

        # GAN loss
        # Targets are built with the prediction's own shape (N, 1), so
        # MSELoss never has to broadcast a (N,) target against it.
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        # Cycle loss
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()

        optimizer_G.step()

        #-------- Discriminator A --------
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # Fake loss: the replay buffer may swap in previously generated images.
        fake_A = fake_A_buffer.push_and_pop(fake_A).to(device)
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()

        optimizer_D_A.step()

        #-------- Discriminator B --------
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # Fake loss
        fake_B = fake_B_buffer.push_and_pop(fake_B).to(device)
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake)*0.5
        loss_D_B.backward()

        optimizer_D_B.step()
        ###################################

        if (i + 1) % 15 == 0:
            print("Epoch: (%3d) (%5d/%5d)" % (epoch, i + 1, len(dataloader)))

    if (epoch + 1) % 5 == 0:  # save sample images every 5 epochs
        save_dir = './sample_images_while_training/cycleGAN'
        utils.mkdir(save_dir)
        # NOTE: fake_A / fake_B are the replay-buffer outputs of the last
        # iteration, so they may contain images generated in earlier iterations.
        torchvision.utils.save_image(fake_A, '%s/Epoch_(%d)_(%dof%d)_fake_A.jpg' % (save_dir, epoch, i + 1, len(dataloader)), nrow=10)
        torchvision.utils.save_image(fake_B, '%s/Epoch_(%d)_(%dof%d)_fake_B.jpg' % (save_dir, epoch, i + 1, len(dataloader)), nrow=10)

        with torch.no_grad():
            # Keep CPU copies of the sample grids for the display cell.
            A.append(vutils.make_grid(fake_A.detach().cpu(), padding=2, normalize=True))
            B.append(vutils.make_grid(fake_B.detach().cpu(), padding=2, normalize=True))

    # Checkpoint after every epoch; 'epoch' stores the next epoch to run.
    utils.save_checkpoint({'epoch': epoch + 1,
                           'netD_A': netD_A.state_dict(),
                           'netD_B': netD_B.state_dict(),
                           'netG_A2B': netG_A2B.state_dict(),
                           'netG_B2A': netG_B2A.state_dict(),
                           'optimizer_G': optimizer_G.state_dict(),
                           'optimizer_D_A': optimizer_D_A.state_dict(),
                           'optimizer_D_B': optimizer_D_B.state_dict(),},
                          '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch + 1),
                          max_keep=2)

显示训练图片

# Show the most recently saved grid of generated domain-A images.
# Index -1 always picks the latest sample and also works when only one
# sample has been saved so far.
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("A")
plt.imshow(np.transpose(A[-1], (1,2,0)))

# Show the most recently saved grid of generated domain-B images.
plt.subplot(1,2,2)
plt.axis("off")
plt.title("B")
plt.imshow(np.transpose(B[-1],(1,2,0)))
plt.show()

 

utils.py文件,其中定义了转cuda、保存模型、加载模型等函数

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import torch


def mkdir(paths):
    """Create every directory in `paths`, including missing parent folders.

    Args:
        paths: A single path, or a list/tuple of paths. Directories that
            already exist are left untouched.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        # `exist_ok=True` removes the check-then-create race of a separate
        # `os.path.isdir` test (another process may create the folder in
        # between the check and the call).
        os.makedirs(path, exist_ok=True)


def cuda_devices(gpu_ids):
    """Set CUDA_VISIBLE_DEVICES to the comma-separated ids in `gpu_ids`."""
    visible = ','.join(map(str, gpu_ids))
    os.environ['CUDA_VISIBLE_DEVICES'] = visible


def cuda(xs):
    """Move `xs` to CUDA when CUDA is available.

    Args:
        xs: A single object with a `.cuda()` method, or a list/tuple of them.

    Returns:
        `xs.cuda()` for a single object, or a list of the moved objects for a
        list/tuple. Without CUDA the inputs are returned unmoved (as a list
        for a list/tuple) rather than `None`, so callers that unpack or use
        the result keep working on CPU-only machines.
    """
    is_sequence = isinstance(xs, (list, tuple))
    if not torch.cuda.is_available():
        return list(xs) if is_sequence else xs
    if is_sequence:
        return [x.cuda() for x in xs]
    return xs.cuda()


def save_checkpoint(state, save_path, is_best=False, max_keep=None):
    """Save `state` to `save_path` and maintain the `latest_checkpoint` list.

    The file `latest_checkpoint`, next to `save_path`, lists checkpoint file
    names one per line, newest first.

    Args:
        state: Object passed to `torch.save`.
        save_path: Destination file of the checkpoint.
        is_best: When true, the checkpoint is also copied to
            `best_model.ckpt` in the same directory.
        max_keep: When not `None`, only the newest `max_keep` listed
            checkpoints are kept; older ones are deleted from disk and
            dropped from the list.
    """
    # save checkpoint
    torch.save(state, save_path)

    # `save_path` itself stays untouched: the "copy best" step below needs the
    # full path, not just the file name.
    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint')
    ckpt_name = os.path.basename(save_path)

    # Read the existing list, dropping blank lines and any older entry for
    # the file that was just overwritten (a duplicate would later cause a
    # still-listed checkpoint to be deleted).
    previous = []
    if os.path.exists(list_path):
        with open(list_path) as f:
            for line in f:
                name = line.rstrip('\r\n')
                if name and name != ckpt_name:
                    previous.append(name + '\n')
    ckpt_list = [ckpt_name + '\n'] + previous

    # deal with max_keep
    if max_keep is not None:
        for stale in ckpt_list[max_keep:]:
            stale_path = os.path.join(save_dir, stale.rstrip('\n'))
            if os.path.isfile(stale_path):
                os.remove(stale_path)
        ckpt_list[max_keep:] = []

    with open(list_path, 'w') as f:
        f.writelines(ckpt_list)

    # copy best
    if is_best:
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))


def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    """Load a checkpoint from a file, or from a checkpoint directory.

    Args:
        ckpt_dir_or_file: A checkpoint file, or a directory. For a directory
            the first entry of its `latest_checkpoint` list is loaded, or
            `best_model.ckpt` when `load_best` is true.
        map_location: Forwarded to `torch.load`.
        load_best: Only used for a directory; selects `best_model.ckpt`.

    Returns:
        The object returned by `torch.load`.

    Raises:
        FileNotFoundError: If the directory has no `latest_checkpoint` list
            or the list is empty.
    """
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            list_path = os.path.join(ckpt_dir_or_file, 'latest_checkpoint')
            with open(list_path) as f:
                # Strip only the line ending: cutting the last character
                # blindly would mangle the name when the file has no
                # trailing newline.
                latest = f.readline().rstrip('\r\n')
            if not latest:
                raise FileNotFoundError('no checkpoint listed in %s' % list_path)
            ckpt_path = os.path.join(ckpt_dir_or_file, latest)
    else:
        ckpt_path = ckpt_dir_or_file
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
    return ckpt

 

posted @ 2021-02-08 19:24  安智伟  阅读(589)  评论(0编辑  收藏  举报