深度学习之使用 GAN 生成 MNIST 图片

深度学习之使用 GAN 生成 MNIST 图片

import codecs
import os

import numpy as np
import PIL
import torch
from PIL import Image

def get_int(b):
    """Interpret a bytes-like object as an unsigned big-endian integer.

    Args:
        b: Bytes-like object holding the big-endian representation.

    Returns:
        The decoded non-negative ``int``. Empty input decodes to 0.
    """
    # int.from_bytes decodes directly; the earlier hex-string round trip
    # through codecs produced the same value with an extra conversion.
    return int.from_bytes(b, 'big')

def extract_image(path, extract_path):
    """Unpack an IDX3 image file into one JPEG file per image.

    Args:
        path: Path of the IDX file; its first four bytes must decode to the
            magic number 2051.
        extract_path: Directory receiving ``image_<index>.jpg`` files. It is
            created when it does not exist yet.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()

    # Header: four big-endian 4-byte fields (magic, image count, rows, cols).
    assert get_int(data[:4]) == 2051, 'unexpected magic number in {}'.format(path)
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # Pixel bytes start right after the 16-byte header.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    parsed = parsed.reshape(length, num_rows, num_cols)

    # Saving into a missing directory fails, so make sure it exists first.
    os.makedirs(extract_path, exist_ok=True)

    for image_i, image in enumerate(parsed):
        target = os.path.join(extract_path, 'image_{}.jpg'.format(image_i))
        Image.fromarray(image, 'L').save(target)
        

# Source IDX image file and the directory meant to receive the extracted images.
# NOTE(review): no call to extract_image(image_path, extract_path) is visible
# in this file — confirm the extraction is run before the JPEGs are read.
image_path = './mnist/t10k-images.idx3-ubyte'
extract_path = './mnist/data/image'

import math

def images_square_grid(images, mode):
    """Tile a batch of images into a single square-grid PIL image.

    Args:
        images: Array indexed as (count, height, width, channels); the
            reshape below relies on exactly four axes.
        mode: PIL mode of the result. For ``'L'`` the channel axis is
            squeezed away, so it must have length 1.

    Returns:
        A PIL image containing the first ``floor(sqrt(count)) ** 2`` images,
        with pixel values stretched to the 0-255 range over the whole batch.
    """
    save_size = math.floor(np.sqrt(images.shape[0]))
    height, width = images.shape[1], images.shape[2]

    # Integer input would wrap around when multiplied by 255; scale in float.
    if np.issubdtype(images.dtype, np.integer):
        images = images.astype(np.float64)

    # Scale to 0-255. A constant batch has zero range; map it to all zeros
    # rather than dividing by zero.
    lowest = images.min()
    value_range = images.max() - lowest
    if value_range == 0:
        images = np.zeros(images.shape, dtype=np.uint8)
    else:
        images = (((images - lowest) * 255) / value_range).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
            images[:save_size*save_size],
            (save_size, save_size, height, width, images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image. PIL sizes and offsets are (x, y), i.e.
    # (width, height) — the reverse of the array's (height, width) axes.
    new_im = Image.new(mode, (width * save_size, height * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * width, image_i * height))

    return new_im

def get_image(image_path, width, height, mode):
    """Load an image file as a numpy array in the requested size and mode.

    An image whose size already equals ``(width, height)`` is only converted
    to ``mode``. Any other image is cropped to a centred 108x108 box and then
    resized to ``(width, height)`` with bilinear filtering.

    Args:
        image_path: Path of the image file to open.
        width: Target width in pixels.
        height: Target height in pixels.
        mode: PIL mode the image is converted to before becoming an array.

    Returns:
        ``numpy.ndarray`` built from the converted image.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_path) as source:
        image = source

        if image.size != (width, height):
            # Both crop dimensions are required; previously only face_width
            # was assigned, so this branch failed with a NameError.
            face_width = face_height = 108
            j = (image.size[0] - face_width) // 2
            i = (image.size[1] - face_height) // 2

            image = image.crop((j, i, j + face_width, i + face_height))
            image = image.resize((width, height), Image.BILINEAR)

        return np.array(image.convert(mode))

def get_batch(image_files, width, height, mode):
    """Load several image files into one float32 array.

    Each file is read through ``get_image``. When the stacked result has
    fewer than four axes, a trailing axis of length 1 is appended.
    """
    loaded = [get_image(path, width, height, mode) for path in image_files]
    batch = np.array(loaded).astype(np.float32)

    if batch.ndim >= 4:
        return batch

    return batch.reshape((*batch.shape, 1))
  
%matplotlib inline
import os
from glob import glob
from matplotlib import pyplot

# Directory whose 'image' subfolder is searched for JPEG files.
data_dir = './mnist/data'
# Upper bound on the number of files loaded for the preview grid.
show_n_images = 25

# Load the first show_n_images matches as 28x28 mode-'L' arrays; glob() does not guarantee an order.
mnist_images = get_batch(glob(os.path.join(data_dir, 'image/*.jpg'))[:show_n_images], 28, 28, 'L')

# Tile the loaded samples into one image and display it with a gray colormap.
pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray')


from torch.utils import data
import torchvision as tv


# Mini-batch size handed to the DataLoader below.
batch_size = 50

# Per-image preprocessing: resize, reduce to one channel, convert to a tensor.
transforms = tv.transforms.Compose([
    tv.transforms.Resize(96),
    # Grayscale(1) replaces PIL.ImageOps.grayscale: a bare `import PIL` does
    # not load the ImageOps submodule, so that attribute lookup only worked
    # when another import had already pulled it in.
    tv.transforms.Grayscale(num_output_channels=1),
    tv.transforms.ToTensor()
])

# NOTE(review): machine-specific absolute Windows path — presumably the same
# directory as the relative './mnist/data' used elsewhere; confirm before changing.
root="d:\\work\\yoho\\dl\\dl-study\\chapter8\\mnist\\data"

dataset = tv.datasets.ImageFolder(root, transform=transforms)
# drop_last=True keeps every batch at exactly batch_size samples.
dataloader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=1, drop_last=True)


import torch.nn as nn
import torch.optim as optim
from torch.nn.modules import loss
from torch.autograd import Variable as V

class GNet(nn.Module):
    """Generator: a stack of transposed convolutions ending in ``Tanh``.

    Args:
        opt: Mapping with the keys
            ``"nz"`` – channel count of the input tensor,
            ``"ngf"`` – base number of feature maps,
            ``"target"`` – output channel count (optional; 3 when missing
            or falsy).
    """

    def __init__(self, opt):
        super(GNet, self).__init__()

        ngf = opt["ngf"]
        # .get() lets a missing key fall back to 3 as well; plain indexing
        # raised KeyError before the `or 3` default could apply.
        target = opt.get("target") or 3

        # Spatial sizes in the comments assume a 1x1 input.
        self.main = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d( opt["nz"], ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # 4x4 -> 8x8
            nn.ConvTranspose2d( ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # 32x32 -> 96x96 (kernel 5, stride 3, padding 1)
            nn.ConvTranspose2d( ngf, target, 5, 3, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Return the generated batch for ``input``; values lie in [-1, 1]."""
        return self.main(input)
    
class DNet(nn.Module):
    """Discriminator: strided convolutions ending in ``Sigmoid``.

    Args:
        opt: Mapping with the keys
            ``"ndf"`` – base number of feature maps,
            ``"input"`` – channel count of the incoming images (optional;
            3 when missing or falsy).
    """

    def __init__(self, opt):
        super(DNet, self).__init__()

        ndf = opt["ndf"]
        # .get() lets a missing key fall back to 3; the local is no longer
        # called `input`, which shadowed the builtin.
        in_channels = opt.get("input") or 3

        # Spatial sizes in the comments assume 96x96 input images.
        self.main = nn.Sequential(
            # 96x96 -> 32x32 (kernel 5, stride 3, padding 1)
            nn.Conv2d(in_channels, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            # NOTE(review): slope 0.3 here vs 0.2 in every other layer —
            # confirm whether this is intentional.
            nn.LeakyReLU(0.3, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 -> 1x1, one score per image
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the sigmoid scores for ``input`` flattened to one dimension."""
        return self.main(input).view(-1)
        

# Learning rates of the generator and discriminator optimisers.
lr_g = 0.01
lr_d = 0.01
# Base feature-map counts passed to GNet ("ngf") and DNet ("ndf").
ngf = 64
ndf = 64
# Image channel count: GNet output ("target") and DNet input ("input").
raw_f = 1
# Channel count of the noise tensor fed to the generator.
nz = 100
# The discriminator is updated every d_every batches, the generator every g_every.
d_every = 1
g_every = 5

net_g = GNet({"target": raw_f, "ngf": ngf, 'nz': nz})
net_d = DNet({"input": raw_f, "ndf": ndf})

opt_g = optim.Adam(net_g.parameters(), lr_g, betas=(0.5, 0.999))
# Uses lr_d: the discriminator optimiser was previously built with lr_g,
# which left lr_d without any effect.
opt_d = optim.Adam(net_d.parameters(), lr_d, betas=(0.5, 0.999))

criterion = torch.nn.BCELoss()

# Label tensors sized for one full batch.
true_labels = V(torch.ones(batch_size))
fake_labels = V(torch.zeros(batch_size))
# fix_noises is the fixed batch passed to the generator in print_image();
# noises is a buffer refilled with fresh samples inside train().
fix_noises = V(torch.randn(batch_size, nz, 1, 1))
noises = V(torch.randn(batch_size, nz, 1, 1))

def train():
    """Run one pass over ``dataloader``, updating both networks.

    The discriminator is updated on every ``d_every``-th batch and the
    generator on every ``g_every``-th batch. The function works on the
    module-level networks, optimisers, labels and noise buffer and returns
    nothing.
    """
    for ii, (img, _) in enumerate(dataloader):
        real_img = V(img)

        if (ii + 1) % d_every == 0:
            # Discriminator step: real images are scored against the ones
            # labels, generated images against the zeros labels.
            opt_d.zero_grad()
            output = net_d(real_img)
            loss_d = criterion(output, true_labels)
            loss_d.backward()

            noises.data.copy_(torch.randn(batch_size, nz, 1, 1))

            # detach() keeps this loss from back-propagating into the generator.
            fake_img = net_g(noises).detach()
            fake_output = net_d(fake_img)
            loss_fake_d = criterion(fake_output, fake_labels)
            loss_fake_d.backward()

            opt_d.step()

        if (ii + 1) % g_every == 0:
            # Generator step: generated images are scored against the ones labels.
            opt_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
            fake_image = net_g(noises)

            # Score the freshly generated, non-detached batch. The detached
            # batch from the discriminator step carries no graph back to the
            # generator, so using it left the generator without gradients.
            fake_output = net_d(fake_image)

            loss_g = criterion(fake_output, true_labels)

            loss_g.backward()
            opt_g.step()


def print_image():
    """Show the generator output for the fixed noise batch as an image grid."""
    fix_fake_imgs = net_g(fix_noises)
    # The generator emits channels-first batches while images_square_grid
    # indexes channels last. permute() moves the axis for any channel count
    # and image size; the earlier view(batch_size, 96, 96, 1) only matched
    # for single-channel 96x96 output. cpu() is a no-op for CPU tensors.
    fix_fake_imgs = fix_fake_imgs.data.cpu().permute(0, 2, 3, 1).numpy()
    pyplot.imshow(images_square_grid(fix_fake_imgs, 'L'), cmap='gray')


# Number of full passes over the data loader.
epochs = 20


def main():
    """Train for ``epochs`` passes and draw the sample grid on even-numbered ones."""
    for epoch in range(epochs):
        print("epoch {}".format(epoch))
        train()

        if epoch % 2:
            continue
        print_image()


main()

注意:GAN 训练很慢,建议使用 GPU 来运行。

posted @ 2018-04-03 13:54  htoooth  阅读(412)  评论(0)  编辑  收藏  举报