GAN生成人脸代码
github:https://github.com/shixiaojia/GAN.git
基于前文介绍的GAN理论，简单实现了用GAN生成人脸，代码如下：
utils.py
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import cv2
import glob
class MyDataset(Dataset):
    """Face-image dataset made of every ``*.jpg`` file directly inside a directory.

    Each item is a float tensor of shape (3, 64, 64) in RGB channel order with
    values in [0, 1] (``transforms.ToTensor`` converts an HWC uint8 image to a
    CHW float tensor divided by 255).
    """

    def __init__(self, img_path, device=None):
        """
        Args:
            img_path: directory containing the ``.jpg`` images; a trailing path
                separator is optional.
            device: if given, every item is moved to this device in
                ``__getitem__``. Leave it as ``None`` (items stay on the CPU)
                when the DataLoader uses worker processes: CUDA cannot be
                initialized in forked workers, so move batches to the GPU in
                the training loop instead.
        """
        super().__init__()
        self.device = device
        self.fnames = self._list_images(img_path)
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])

    @staticmethod
    def _list_images(img_path):
        """Return the sorted paths of the ``*.jpg`` files directly inside ``img_path``."""
        # os.path.join supplies the separator itself, so "faces" and "faces/"
        # both work; sorting makes the index -> file mapping reproducible.
        return sorted(glob.glob(os.path.join(img_path, "*.jpg")))

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; fail here with the offending path.
            raise IOError("failed to read image: {}".format(fname))
        img = cv2.resize(img, (64, 64))
        # OpenCV decodes to BGR; convert so that generated samples written with
        # torchvision's save_image (which assumes RGB) have the right colours.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        if self.device is not None:
            img = img.to(self.device)
        return img

    def __len__(self):
        return len(self.fnames)
def gradient_penality(discriminator, real, fake, device='cpu'):
    """WGAN-GP gradient penalty.

    Draws one weight ``alpha ~ U[0, 1)`` per sample, forms the interpolates
    ``alpha * real + (1 - alpha) * fake``, and penalizes the squared deviation
    of the per-sample L2 norm of the discriminator's input gradient from 1.

    Args:
        discriminator: callable mapping a batch shaped like ``real`` to scores.
        real: batch of real samples, shape (B, ...).
        fake: batch of generated samples, same shape as ``real``.
        device: kept only for backward compatibility; the interpolation
            weights are created on ``real.device`` so they always match the
            inputs.

    Returns:
        0-dim tensor ``mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True`` so it can be back-propagated into the
        discriminator's parameters.
    """
    batch = real.shape[0]
    # One weight per sample, broadcast over the remaining dims. Uniform
    # (torch.rand, not randn): the penalty is defined on points *between*
    # real and fake samples, not on extrapolations beyond them.
    alpha_shape = (batch,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device, dtype=real.dtype)
    interpolated = (real * alpha + fake * (1 - alpha)).requires_grad_(True)
    scores = discriminator(interpolated)
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=scores,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.flatten(start_dim=1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1).square())
if __name__ == "__main__":
    # Smoke test: build the dataset and report how many images were found.
    # The trailing "/" keeps the glob pattern valid however MyDataset joins it.
    img_path = '/home/shixiaojia/dl/datasets/faces/'
    dataset = MyDataset(img_path, device='cpu')
    print(len(dataset))
network.py
import torch
import torch.nn as nn


class Reshape(nn.Module):
    """Reshape the generator's linear output into (N, 3, 4, 4) feature maps."""

    def __init__(self):
        super(Reshape, self).__init__()

    def forward(self, x):
        x = x.view(-1, 3, 4, 4)
        return x


def conv_bn_relu(in_dim, out_dim):
    """Return a stride-2 5x5 Conv2d + BatchNorm2d + LeakyReLU(0.2) stage.

    With padding 2 the conv halves even spatial sizes (64 -> 32 -> ...), which
    is what Discriminator needs to reach 4x4 before its final 4x4 conv.
    """
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, 5, 2, 2),
        nn.BatchNorm2d(out_dim),
        nn.LeakyReLU(0.2),
    )


class Generator(nn.Module):
    """Map (N, in_size) noise vectors to (N, 3, 64, 64) images in (0, 1).

    A linear layer projects the noise to 4*4*3 values. These are reshaped to a
    3x4x4 map and upsampled 4 -> 8 -> 16 -> 32 -> 64 by four stride-2
    transposed convolutions. A final Sigmoid bounds the pixels.
    """

    def __init__(self, in_size):
        super(Generator, self).__init__()
        # NOTE(review): the whole noise vector is squeezed into only 48 values
        # before upsampling; this narrow bottleneck may limit sample quality.
        self.l1 = nn.Linear(in_features=in_size, out_features=4 * 4 * 3)
        self.reshape = Reshape()
        # Each ConvTranspose2d(k=3, s=2, p=1, output_padding=1) doubles H and W:
        # (H - 1) * 2 - 2 + 3 + 1 = 2H.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=3, out_channels=64, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.l1(x)
        x = self.reshape(x)
        x = self.conv(x)
        return x


class Discriminator(nn.Module):
    """Score (N, in_size, 64, 64) images; returns shape (N,) values in (0, 1).

    Spatial size goes 64 -> 32 (first conv) -> 16 -> 8 -> 4 (three
    conv_bn_relu stages), then a 4x4 conv reduces it to 1x1. The Sigmoid
    output pairs with nn.BCELoss in the training script. Inputs of another
    spatial size do not produce one score per image.
    """

    def __init__(self, in_size=3, size=64):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_size, size, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_relu(size, 2 * size),
            conv_bn_relu(2 * size, 4 * size),
            conv_bn_relu(4 * size, 8 * size),
            nn.Conv2d(8 * size, 1, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1)
        return x
train_vanilla_gan.py
import os
import torch
import torch.nn as nn
from utils import MyDataset
from torch.utils.data import DataLoader
from network import Generator, Discriminator
from tqdm import tqdm
import argparse
from torchvision.utils import save_image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def argsParser():
    """Parse the command-line hyper-parameters for GAN training."""
    parser = argparse.ArgumentParser(description='prepare for training.')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size for training.')
    parser.add_argument('--n_epoch', type=int, default=100, help='num of epochs for training.')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate for training.')
    parser.add_argument('--noisy_dim', type=int, default=128, help='dims for noisy.')
    parser.add_argument('--num_workers', type=int, default=1, help='num of workers for dataloader.')
    return parser.parse_args()


def _train_one_epoch(gen, dis, dataloader, criterion, gen_opt, dis_opt, noisy_dim, desc):
    """Run one epoch of alternating D/G updates; return (mean dis loss, mean gen loss)."""
    gen.train()
    dis.train()
    total_gen_loss = 0.
    total_dis_loss = 0.
    step = 0
    for data in tqdm(dataloader, desc=desc):
        r_imgs = data.to(DEVICE)
        N = r_imgs.shape[0]
        r_label = torch.ones(N, device=DEVICE)
        f_label = torch.zeros(N, device=DEVICE)
        noisy = torch.randn(N, noisy_dim, device=DEVICE)
        f_imgs = gen(noisy)

        # Discriminator update: real -> 1, fake -> 0. detach() keeps this
        # backward pass out of the generator.
        r_logit = dis(r_imgs)
        f_logit = dis(f_imgs.detach())
        loss_dis = (criterion(r_logit, r_label) + criterion(f_logit, f_label)) / 2
        dis_opt.zero_grad()
        loss_dis.backward()
        dis_opt.step()

        # Generator update: minimize BCE(D(G(z)), 1), i.e. push D to label fakes as real.
        f_logit = dis(f_imgs)
        loss_gen = criterion(f_logit, r_label)
        gen_opt.zero_grad()
        loss_gen.backward()
        gen_opt.step()

        # .item() turns the losses into Python floats; summing the tensors
        # themselves would keep every step's autograd graph alive all epoch.
        total_dis_loss += loss_dis.item()
        total_gen_loss += loss_gen.item()
        step += 1
    if step == 0:
        return float('nan'), float('nan')
    return total_dis_loss / step, total_gen_loss / step


def train(args):
    """Train the vanilla (BCE-loss) GAN on the ``*.jpg`` images in ``../faces/``.

    After every epoch the mean losses are printed and an 8x8 grid of 64
    generated samples is written to ``./my_generated-images-XXXX.png``. Every
    10th epoch, and the final one, the generator and discriminator weights are
    written to ``./ckpt/``.
    """
    gen = Generator(args.noisy_dim).to(DEVICE)
    dis = Discriminator().to(DEVICE)
    criterion = nn.BCELoss()  # Discriminator already ends in a Sigmoid.
    gen_opt = torch.optim.Adam(gen.parameters(), lr=args.learning_rate)
    dis_opt = torch.optim.Adam(dis.parameters(), lr=args.learning_rate)

    # Load images on the CPU: DataLoader workers are separate processes and
    # CUDA cannot be initialized in forked workers. Batches are moved to
    # DEVICE in the main process inside _train_one_epoch.
    my_dataset = MyDataset(img_path='../faces/', device='cpu')
    dataloader = DataLoader(dataset=my_dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers)

    for e in range(args.n_epoch):
        loss_dis, loss_gen = _train_one_epoch(gen, dis, dataloader, criterion, gen_opt, dis_opt,
                                              args.noisy_dim, 'Epoch {}: '.format(e))
        print('Epoch {}: loss_dis={:.4f}, loss_gen={:.4f}'.format(e, loss_dis, loss_gen))

        # Sample without building an autograd graph that is never used.
        gen.eval()
        with torch.no_grad():
            images = gen(torch.randn(64, args.noisy_dim, device=DEVICE))
        save_image(images, './my_generated-images-{0:0=4d}.png'.format(e), nrow=8)
        # NOTE(review): this directory is created but nothing in this script writes to it.
        os.makedirs('./logs/{}'.format(e), exist_ok=True)

        if e % 10 == 0 or e == args.n_epoch - 1:
            os.makedirs('./ckpt', exist_ok=True)
            torch.save(gen.state_dict(), './ckpt/gen_{}.pth'.format(e))
            torch.save(dis.state_dict(), './ckpt/dis_{}.pth'.format(e))


if __name__ == '__main__':
    args = argsParser()
    train(args)
未做调参，训练 50 个 epoch 后，模型生成的结果如下：
生成效果不理想，后续有时间再进行优化。