利用GAN来生成手写数字图片

直接贴代码:

首先导入相应的包:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image 

然后对数据进行预处理:

# Map MNIST pixels from [0, 1] to [-1, 1] so real images share the
# generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Download (if needed) and load the MNIST training split.
train_set = torchvision.datasets.MNIST(
    root="./data/mnist",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 8 images.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=8,
    shuffle=True,
)

定义网络模型:

class Discriminator(nn.Module):
    """CNN discriminator for 28x28 single-channel images.

    Maps a batch of shape [B, 1, 28, 28] to a probability in (0, 1)
    that each image is real: two conv + leaky-ReLU + average-pool
    stages, then a two-layer MLP head ending in a sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Stage 1: [B, 1, 28, 28] -> [B, 32, 14, 14]
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),
        )
        # Stage 2: [B, 32, 14, 14] -> [B, 64, 7, 7]
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),
        )
        # Classifier head: flattened conv features -> real/fake score.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the per-image probability of being real, shape [B, 1]."""
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # [B, 64*7*7]
        return self.fc(flat)

class Generator(nn.Module):
    """Generator that turns a latent noise vector into a 28x28 image.

    A linear layer expands the latent code to a flat 56*56 map; two
    3x3 conv stages refine it at 56x56, and a final stride-2 conv
    downsamples to 28x28 with Tanh output in [-1, 1] (matching the
    normalized training images).
    """

    def __init__(self, input_size, num_feature):
        super(Generator, self).__init__()
        # Latent vector -> flat feature map (expected num_feature = 56*56).
        self.fc = nn.Linear(input_size, num_feature)
        # BatchNorm + ReLU applied right after reshaping to [B, 1, 56, 56].
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        # Refinement at full 56x56 resolution: 1 -> 50 channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),
        )
        # 50 -> 25 channels, still 56x56.
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),
        )
        # Stride-2 conv halves the resolution: [B, 1, 28, 28], Tanh range.
        self.conv3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent codes [B, input_size] to images [B, 1, 28, 28]."""
        out = self.fc(x).view(x.size(0), 1, 56, 56)
        for stage in (self.br, self.conv1, self.conv2, self.conv3):
            out = stage(out)
        return out

设置超参数以及初始化模型:

# ---- Hyperparameters and model setup ----
lr = 3e-4            # Adam learning rate for both networks
batch_size = 8       # images per mini-batch (matches the DataLoader above)
z_dimension = 100    # size of the generator's latent noise vector
n_epoch = 20         # total training epochs

# Prefer the GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator expands z to a 56*56 intermediate map, then emits 28x28 images.
G = Generator(z_dimension, 56 * 56).to(device)
D = Discriminator().to(device)

# Binary cross-entropy over D's sigmoid output for real/fake labels.
# (The original post used `criterion` in the training loop without ever
# defining it — the script would crash with a NameError.)
criterion = nn.BCELoss()

G_optimizer = optim.Adam(G.parameters(), lr)
D_optimizer = optim.Adam(D.parameters(), lr)

开始训练:

import os

# save_image raises if the target directory does not exist.
os.makedirs("./images", exist_ok=True)

for epoch in range(n_epoch):
    for index, (images, labels) in enumerate(train_loader):
        real_images = images.to(device)
        # Size label tensors by the actual batch (the last batch may be
        # smaller) and allocate them on `device` instead of hard-coded
        # .cuda(), so the script also runs on CPU-only machines.
        cur_batch = images.size(0)
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Train the discriminator ----
        real_out = D(real_images)
        D_loss_real = criterion(real_out, real_labels)
        # Generate fakes; detach so this step does not backprop into G.
        z = torch.randn(cur_batch, z_dimension, device=device)
        fake_images = G(z)
        fake_out = D(fake_images.detach())
        D_loss_fake = criterion(fake_out, fake_labels)
        D_loss = (D_loss_real + D_loss_fake) / 2.0
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train the generator ----
        # G wins when D classifies its fakes as real, hence real_labels.
        z = torch.randn(cur_batch, z_dimension, device=device)
        fake_images = G(z)
        fake_out = D(fake_images)
        G_loss = criterion(fake_out, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if (index + 1) % 100 == 0:
            print("[%d/%d] [%d/%d] G_loss: %.06f D_loss: %.06f"
                  % (epoch + 1, n_epoch, index + 1, len(train_loader),
                     G_loss.item(), D_loss.item()))

    # Sample a grid of images at the end of each epoch; no gradients needed.
    with torch.no_grad():
        z = torch.randn(batch_size, z_dimension, device=device)
        imgs = G(z)
    # G outputs Tanh values in [-1, 1]; map back to [0, 1] before saving
    # (save_image clamps to [0, 1], so the raw output would lose the
    # negative half of the range).
    save_image(imgs * 0.5 + 0.5, './images/images_%d.png' % (epoch + 1))

这是第一个epoch的结果:

这是第10个epoch的结果:

这是第20个epoch的结果:

 

可以看到生成的图片质量越来越好,即越来越像训练数据中的图片。

posted @ 2020-03-21 23:22  liualex_sone  阅读(427)  评论(0编辑  收藏  举报