VAE生成人脸代码

GitHub:https://github.com/shixiaojia/VAE.git

基于前文对VAE理论的介绍,简单实现用VAE生成人脸,代码如下:

utils.py

import os
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import glob
import cv2
import numpy as np
import torch


class MyDataset(Dataset):
    """Dataset of the ``*.jpg`` files found directly inside one directory.

    Each item is decoded with OpenCV, converted from OpenCV's BGR channel
    order to RGB (so that tensors, and any images later written with
    ``torchvision.utils.save_image``, have the usual RGB layout), turned into a
    float ``(C, H, W)`` tensor by ``transforms.ToTensor`` and moved to
    ``device``.

    Args:
        img_path: Directory holding the images; a trailing path separator is
            optional.
        device: Device every returned tensor is moved to. NOTE(review): moving
            to CUDA inside ``__getitem__`` presumably requires
            ``num_workers=0`` in the DataLoader (the training script uses 0).

    Raises:
        OSError: from ``__getitem__`` when a file cannot be decoded.
    """

    def __init__(self, img_path, device):
        super(MyDataset, self).__init__()
        self.device = device
        # Join directory and pattern as separate components so "dir" and
        # "dir/" both work; sort so the file list is deterministic.
        self.fnames = sorted(glob.glob(os.path.join(img_path, "*.jpg")))
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"failed to read image: {fname}")
        # OpenCV decodes to BGR; convert so the model trains on (and the
        # sampler emits) RGB images.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        return img.to(self.device)

    def __len__(self):
        return len(self.fnames)

 

VAE.py

import torch
import torch.nn as nn


class VAE(nn.Module):
    """Convolutional variational auto-encoder.

    Encoder: one ``Conv2d(stride=2) -> BatchNorm2d -> LeakyReLU`` stage per
    entry of ``hid_dims``, each halving height and width. The flattened result
    feeds two linear heads that produce the posterior mean and log-variance.
    Decoder: a linear layer back to the last feature map, then mirrored
    ``ConvTranspose2d(stride=2)`` stages, and finally a ``Conv2d`` plus
    ``Sigmoid``, so reconstructions lie in [0, 1].

    Args:
        image_size: Height/width of the square input images. Must be divisible
            by ``2 ** len(hid_dims)`` so the decoder reproduces the input size.
        in_channels: Channels of the input images; the decoder emits the same
            number of channels.
        latent_dim: Size of the latent code ``z``.
        hid_dims: Channel widths of the encoder stages. The decoder uses them
            in reverse order. Defaults to ``[32, 64, 128, 256]`` when ``None``
            or empty. The caller's list is never modified.

    Raises:
        ValueError: if ``image_size`` is not divisible by
            ``2 ** len(hid_dims)``.
    """

    def __init__(self, image_size: int, in_channels: int, latent_dim: int, hid_dims: list = None):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        # Copy, so that building the reversed decoder never mutates the
        # caller's list.
        hid_dims = list(hid_dims) if hid_dims else [32, 64, 128, 256]
        out_channels = in_channels  # reconstructions match the input channels

        n_stages = len(hid_dims)
        scale = 2 ** n_stages
        if image_size % scale != 0:
            raise ValueError(
                f"image_size={image_size} must be divisible by 2**len(hid_dims)={scale}"
            )
        # Every stride-2 encoder stage halves H and W.
        self.feature_size = image_size // scale
        self.last_hid_dim = hid_dims[-1]
        flat_features = self.last_hid_dim * self.feature_size ** 2

        # encoder
        modules = []
        for h_d in hid_dims:
            modules.append(nn.Sequential(nn.Conv2d(in_channels, h_d, 3, 2, 1),
                                         nn.BatchNorm2d(h_d),
                                         nn.LeakyReLU()))
            in_channels = h_d

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # decoder: mirror of the encoder. Each ConvTranspose2d(k=3, s=2, p=1,
        # output_padding=1) doubles H and W.
        self.decoder_input = nn.Linear(latent_dim, flat_features)
        rev_dims = hid_dims[::-1]

        modules = []
        for c_in, c_out in zip(rev_dims, rev_dims[1:]):
            modules.append(nn.Sequential(nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, 1),
                                         nn.BatchNorm2d(c_out),
                                         nn.LeakyReLU()))

        self.decoder = nn.Sequential(*modules)

        # The last upsampling stage, then a projection to image channels in [0, 1].
        self.decoder_out = nn.Sequential(nn.ConvTranspose2d(rev_dims[-1], rev_dims[-1], 3, 2, 1, 1),
                                         nn.BatchNorm2d(rev_dims[-1]),
                                         nn.LeakyReLU(),
                                         nn.Conv2d(rev_dims[-1], out_channels, 3, 1, 1, 1),
                                         nn.Sigmoid())

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for batch ``x``."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def decode(self, x):
        """Map latent codes ``x`` of shape ``(B, latent_dim)`` to images."""
        x = self.decoder_input(x)
        # Un-flatten into the encoder's final feature-map shape.
        x = x.view(-1, self.last_hid_dim, self.feature_size, self.feature_size)
        x = self.decoder(x)
        x = self.decoder_out(x)
        return x

    def re_parameterize(self, mu, log_var):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for batch ``x``."""
        mu, log_var = self.encode(x)
        z = self.re_parameterize(mu, log_var)
        out = self.decode(z)
        return out, mu, log_var

    def sample(self, n_samples, device):
        """Decode ``n_samples`` latent codes drawn from the standard normal prior."""
        z = torch.randn((n_samples, self.latent_dim), device=device)
        samples = self.decode(z)
        return samples


if __name__ == '__main__':
    # Smoke test: push one dummy batch through the model and draw a few samples.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the weights and every input on the same device. Otherwise sample()
    # would create z on CUDA while the model stays on the CPU.
    model = VAE(96, 3, 1024).to(DEVICE)
    fake_input = torch.ones((1, 3, 96, 96), device=DEVICE)
    out, *_ = model(fake_input)
    print(out.shape)
    with torch.no_grad():
        print(model.sample(10, DEVICE).shape)

 

Loss.py

import torch
import torch.nn as nn


class Loss(nn.Module):
    """VAE objective: mean-squared reconstruction error plus a weighted KL term.

    The KL divergence between the diagonal Gaussian posterior
    ``N(mu, exp(log_var))`` and the standard normal prior is computed in
    closed form, ``-0.5 * (1 + log_var - mu**2 - exp(log_var))``, and averaged
    over every element of ``mu`` / ``log_var``.

    Args:
        kld_weight: Multiplier applied to the KL term.
    """

    def __init__(self, kld_weight=0.03):
        super().__init__()
        self.kld_weight = kld_weight
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, input, output, mu, log_var):
        """Return ``MSE(output, input) + kld_weight * KL`` as a scalar tensor."""
        reconstruction_term = self.criterion(output, input)
        elementwise = 1 + log_var - mu ** 2 - torch.exp(log_var)
        divergence = -0.5 * elementwise.mean()
        return reconstruction_term + self.kld_weight * divergence

 

train_vae.py

import os
import numpy as np
import torch
from VAE import VAE
import argparse
from torch.utils.data import DataLoader
from PIL import Image
from torch.optim import Adam
from utils import MyDataset
from torchvision.utils import save_image
from Loss import Loss
from tqdm import tqdm


def args_parser():
    """Parse the command-line hyper-parameters used to train the VAE.

    Returns:
        argparse.Namespace with ``batch_size``, ``in_channels``,
        ``latent_dim``, ``lr``, ``weight_decay``, ``epoch``, ``snap_epoch``,
        ``num_samples`` and ``path``.
    """
    # (short flag, long flag, type, default) — registered in this order.
    option_table = (
        ("-b", "--batch_size", int, 128),
        ("-i", "--in_channels", int, 3),
        ("-d", "--latent_dim", int, 256),
        ("-l", "--lr", float, 1e-3),
        ("-w", "--weight_decay", float, 1e-5),
        ("-e", "--epoch", int, 500),
        ("-v", "--snap_epoch", int, 1),
        ("-n", "--num_samples", int, 64),
        ("-p", "--path", str, "./results_linear"),
    )
    cli = argparse.ArgumentParser(description="Parameters of training vae model")
    for short_flag, long_flag, value_type, default_value in option_table:
        cli.add_argument(short_flag, long_flag, type=value_type, default=default_value)
    return cli.parse_args()


def train(model, input_data, loss_fn, optimizer):
    """Run one optimisation step on a single batch and report its loss.

    Args:
        model: Callable returning ``(reconstruction, mu, log_var)`` for
            ``input_data``.
        input_data: The batch, already on the model's device.
        loss_fn: Called as ``loss_fn(input_data, reconstruction, mu, log_var)``;
            must return a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The batch loss (computed before the update) as a Python float, so
        callers can aggregate or log it.
    """
    optimizer.zero_grad()
    out, mu, log_var = model(input_data)
    total_loss = loss_fn(input_data, out, mu, log_var)
    total_loss.backward()
    optimizer.step()

    loss_value = total_loss.item()
    print("loss:", loss_value)
    return loss_value


if __name__ == '__main__':
    # Train the VAE on ../faces/*.jpg and periodically save samples and weights
    # under <opt.path>/images and <opt.path>/models.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = args_parser()

    loss_fn = Loss(kld_weight=0.03)

    data_dir = "../faces/"
    dataset = MyDataset(img_path=data_dir, device=DEVICE)
    if len(dataset) == 0:
        # Otherwise the loop below would silently "train" on nothing.
        raise SystemExit(f"no *.jpg images found in {data_dir}")
    train_loader = DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
    model = VAE(image_size=96, in_channels=opt.in_channels, latent_dim=opt.latent_dim)
    model.to(DEVICE)

    optimizer = Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    saved_image_path = os.path.join(opt.path, "images")
    saved_model_path = os.path.join(opt.path, "models")
    os.makedirs(saved_image_path, exist_ok=True)
    os.makedirs(saved_model_path, exist_ok=True)

    for epoch in range(opt.epoch):
        model.train()
        for data in tqdm(train_loader):
            train(model, data.to(DEVICE), loss_fn, optimizer)

        if epoch % opt.snap_epoch == 0 or epoch == opt.epoch - 1:
            model.eval()
            # Sampling is inference only; don't build an autograd graph.
            with torch.no_grad():
                images = model.sample(opt.num_samples, DEVICE)
            # Write into the images directory created above, not the CWD.
            fname = os.path.join(saved_image_path, 'my_generated-images-{0:0=4d}.png'.format(epoch))
            save_image(images, fname, nrow=8)
            torch.save(model.state_dict(), os.path.join(saved_model_path, f"epoch_{epoch}.pth"))

 

 

 

未做调参,训练333个epoch后,模型生成的结果如下(原文此处为生成图片,未能保留):

posted @ 2024-06-30 15:43  指间的执着  阅读(6)  评论(0)  编辑  收藏  举报