GAN生成人脸代码

github:https://github.com/shixiaojia/GAN.git

基于GAN介绍的理论,简单实现GAN生成人脸,代码如下:

utils.py

import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import cv2
import glob


class MyDataset(Dataset):
    """Face-image dataset: every ``*.jpg`` under ``img_path`` as a
    64x64 RGB float tensor in [0, 1].

    Args:
        img_path: directory containing the ``.jpg`` face images.
        device: torch device the returned tensors are moved to.
    """

    def __init__(self, img_path, device):
        super(MyDataset, self).__init__()
        self.device = device
        # BUG FIX: os.path.join was given a single pre-concatenated argument
        # (img_path + "*.jpg"), which only worked when img_path ended with a
        # path separator.  Joining the two parts works for any path.
        self.fnames = glob.glob(os.path.join(img_path, "*.jpg"))
        self.transforms = transforms.Compose([
            transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        ])

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        img = cv2.resize(img, (64, 64))
        # BUG FIX: OpenCV decodes to BGR; convert to RGB so that samples
        # later written with torchvision's save_image have correct colors.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        # NOTE(review): moving tensors to a CUDA device inside DataLoader
        # worker processes (num_workers > 0) can fail; kept for interface
        # compatibility -- the training loop moves each batch again anyway.
        img = img.to(self.device)
        return img

    def __len__(self):
        return len(self.fnames)


def gradient_penality(discriminator, real, fake, device='cpu'):
    """WGAN-GP gradient penalty.

    Samples a random point on the line between each real/fake pair and
    penalizes the squared deviation of the discriminator's gradient norm
    from 1 at that point.

    Args:
        discriminator: callable mapping an image batch to per-sample scores.
        real: real images, shape (B, C, H, W).
        fake: generated images, same shape as ``real``.
        device: device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor: mean squared distance of the gradient norm from 1.
    """
    b, c, h, w = real.shape
    # BUG FIX: the interpolation coefficient must be uniform in [0, 1]
    # (torch.rand), not standard normal (torch.randn) -- randn produces
    # negative values and values > 1, so the "interpolated" points lie far
    # off the real/fake line required by the WGAN-GP formulation.
    alpha = torch.rand((b, 1, 1, 1)).repeat(1, c, h, w).to(device)
    interpolated_images = (real * alpha + fake * (1 - alpha)).requires_grad_(True)
    scores = discriminator(interpolated_images)

    # create_graph=True so the penalty itself is differentiable and can be
    # backpropagated through during the discriminator update.
    gradient = torch.autograd.grad(inputs=interpolated_images,
                                   outputs=scores,
                                   grad_outputs=torch.ones_like(scores),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]

    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1).square())
    return penalty


if __name__ == "__main__":
    # Smoke test: build the dataset and report how many images were found.
    img_path = '/home/shixiaojia/dl/datasets/faces'
    # BUG FIX: MyDataset requires the `device` argument; calling it with
    # only img_path raised a TypeError.
    dataset = MyDataset(img_path, device=torch.device("cpu"))
    print(len(dataset))

 

network.py

import torch
import torch.nn as nn


class Reshape(nn.Module):
    """Reshape a flat feature batch into (N, 3, 4, 4) maps for the
    generator's transposed-convolution stack."""

    def __init__(self):
        super(Reshape, self).__init__()

    def forward(self, inputs):
        # -1 lets view() infer the batch dimension from the input size.
        return inputs.view(-1, 3, 4, 4)


class Generator(nn.Module):
    """Generate (N, 3, 64, 64) images in [0, 1] from latent vectors.

    A linear layer projects the latent code to 4*4*3 features which are
    reshaped to (3, 4, 4) maps; four stride-2 transposed convolutions then
    double the resolution at each step (4 -> 8 -> 16 -> 32 -> 64) and a
    final sigmoid squashes the pixels into [0, 1].
    """

    def __init__(self, in_size):
        super(Generator, self).__init__()
        self.l1 = nn.Linear(in_features=in_size, out_features=4 * 4 * 3)
        self.reshape = Reshape()

        # Build the up-sampling stack as a flat list so the Sequential
        # indices (and hence state_dict keys) match a hand-written stack.
        channels = [3, 64, 64, 64]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # ConvTranspose2d(k=3, s=2, p=1, output_padding=1) doubles H, W.
            layers.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out,
                                             kernel_size=3, stride=2,
                                             output_padding=1, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        layers.append(nn.ConvTranspose2d(in_channels=64, out_channels=3,
                                         kernel_size=3, stride=2,
                                         output_padding=1, padding=1))
        layers.append(nn.Sigmoid())
        self.conv = nn.Sequential(*layers)

    def forward(self, z):
        out = self.l1(z)
        out = self.reshape(out)
        return self.conv(out)


def _conv_bn_relu(in_ch, out_ch):
    """Stride-2 conv block (halves H and W) with BatchNorm + LeakyReLU.

    NOTE(review): the listing referenced a ``conv_bn_relu`` helper that is
    not defined anywhere in the file (NameError at construction time).
    This definition follows the standard DCGAN down-sampling block and
    matches the spatial arithmetic the model needs (64 -> 32 -> 16 -> 8 -> 4);
    confirm against the original repository.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 5, 2, 2),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.2),
    )


class Discriminator(nn.Module):
    """Vanilla-GAN discriminator scoring (N, 3, 64, 64) images.

    Returns a flat tensor of N probabilities in (0, 1) (sigmoid output,
    suitable for nn.BCELoss).

    Args:
        in_size: number of input channels.
        size: base channel width, doubled at each down-sampling stage.
    """
    def __init__(self, in_size=3, size=64):
        super(Discriminator, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_size, size, 5, 2, 2),   # 64 -> 32
            nn.LeakyReLU(0.2),
            _conv_bn_relu(size, 2 * size),       # 32 -> 16
            _conv_bn_relu(2 * size, 4 * size),   # 16 -> 8
            _conv_bn_relu(4 * size, 8 * size),   # 8 -> 4
            nn.Conv2d(8 * size, 1, 4),           # 4x4 -> 1x1 logit map
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        # Flatten (N, 1, 1, 1) to (N,) so it matches the label tensors.
        x = x.view(-1)
        return x

 

train_vanilla_gan.py

import os
import torch
import torch.nn as nn
from utils import MyDataset
from torch.utils.data import DataLoader
from network import Generator, Discriminator
from tqdm import tqdm
import argparse
from torchvision.utils import save_image

# Run on the GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def argsParser():
    """Build the training CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='prepare for training.')
    # (flag, type, default, help) — registered in a single loop.
    options = (
        ('--batch_size', int, 16, 'batch size for training.'),
        ('--n_epoch', int, 100, 'num of epochs for training.'),
        ('--learning_rate', float, 1e-4, 'learning rate for training.'),
        ('--noisy_dim', int, 128, 'dims for noisy.'),
        ('--num_workers', int, 1, 'num of workers for dataloader.'),
    )
    for flag, arg_type, default, help_text in options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)
    return parser.parse_args()


def train(args):
    """Train a vanilla GAN on face images.

    Per batch: the discriminator is updated on real images (label 1) and
    detached fakes (label 0); the generator is then updated to make the
    discriminator output 1 on its fakes.  A 64-sample grid is written every
    epoch and checkpoints are saved every 10 epochs.

    Args:
        args: parsed CLI namespace from argsParser().
    """
    gen = Generator(args.noisy_dim).to(DEVICE)
    dis = Discriminator().to(DEVICE)

    criterion = nn.BCELoss()

    gen_opt = torch.optim.Adam(gen.parameters(), lr=args.learning_rate)
    dis_opt = torch.optim.Adam(dis.parameters(), lr=args.learning_rate)

    my_dataset = MyDataset(img_path='../faces/', device=DEVICE)
    dataloader = DataLoader(dataset=my_dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers)

    for e in range(args.n_epoch):
        gen.train()
        dis.train()
        total_gen_loss = 0.
        total_dis_loss = 0.
        step = 0
        for data in tqdm(dataloader, desc='Epoch {}: '.format(e)):
            data = data.to(DEVICE)
            N, *_ = data.shape
            noisy = torch.randn(N, args.noisy_dim).to(DEVICE)
            r_label = torch.ones((N, )).to(DEVICE)
            f_label = torch.zeros((N, )).to(DEVICE)

            f_imgs = gen(noisy)

            # ---- discriminator step (fakes detached so only D updates) ----
            r_logit = dis(data)
            f_logit = dis(f_imgs.detach())
            r_loss = criterion(r_logit, r_label)
            f_loss = criterion(f_logit, f_label)
            loss_dis = (r_loss + f_loss) / 2
            # BUG FIX: accumulate the Python float (.item()), not the tensor.
            # Summing graph-attached tensors keeps every batch's autograd
            # graph alive, leaking memory over the epoch.
            total_dis_loss += loss_dis.item()
            dis_opt.zero_grad()
            loss_dis.backward()
            dis_opt.step()

            # ---- generator step: fool D into predicting "real" ----
            f_logit = dis(f_imgs)
            loss_gen = criterion(f_logit, r_label)
            total_gen_loss += loss_gen.item()
            gen_opt.zero_grad()
            loss_gen.backward()
            gen_opt.step()

            step += 1

        # The running losses were previously accumulated but never reported.
        if step:
            print('Epoch {}: dis_loss={:.4f}, gen_loss={:.4f}'.format(
                e, total_dis_loss / step, total_gen_loss / step))

        # ---- sample a fixed grid of fakes for visual inspection ----
        gen = gen.eval()
        # BUG FIX: sampling without no_grad built an unused autograd graph.
        with torch.no_grad():
            noisy = torch.randn(64, args.noisy_dim).to(DEVICE)
            images = gen(noisy)

        fname = './my_generated-images-{0:0=4d}.png'.format(e)
        save_image(images, fname, nrow=8)

        # NOTE(review): these per-epoch log directories are created but never
        # written to in this listing; kept for compatibility.
        os.makedirs('./logs/{}'.format(e), exist_ok=True)

        if e % 10 == 0:
            # exist_ok=True makes the prior os.path.exists check redundant.
            os.makedirs('./ckpt', exist_ok=True)
            torch.save(gen.state_dict(), './ckpt/gen_{}.pth'.format(e))
            torch.save(dis.state_dict(), './ckpt/dis_{}.pth'.format(e))


if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the training loop.
    args = argsParser()
    train(args)

 

没有调参,训练50个epoch,模型生成的结果如下:

 

生成效果不好,后续有时间需要优化一下。

 

posted @ 2024-06-30 20:08  指间的执着  阅读(10)  评论(0编辑  收藏  举报