利用GAN来生成手写数字图片
直接贴代码:
首先导入相应的包:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
然后对数据进行预处理:
# ToTensor() yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) rescales
# them to [-1, 1], the same range as the generator's final tanh.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ./data/mnist on first use.
train_set = torchvision.datasets.MNIST(
    root="./data/mnist",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 8 images each.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True)
定义网络模型:
class Discriminator(nn.Module):
    """Convolutional classifier scoring single-channel 28x28 images as real or fake.

    Input is ``(batch, 1, 28, 28)``; the two 2x average-pools bring the map to
    7x7, which is what the ``64 * 7 * 7`` linear layer expects. Output is
    ``(batch, 1)`` probabilities from a final sigmoid (use with ``nn.BCELoss``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),  # [32, 28, 28]
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # [32, 14, 14]
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # [64, 14, 14]
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),  # [64, 7, 7]
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # [batch_size, 64*7*7]
        return self.fc(x)


class Generator(nn.Module):
    """Map a latent vector to a single-channel image with values in [-1, 1].

    ``z`` is projected by a linear layer to ``num_feature`` values, reshaped to
    a ``1 x side x side`` map (``side = sqrt(num_feature)``), refined by two
    size-preserving 3x3 convolutions, then halved by a 2x2 stride-2
    convolution followed by ``tanh``. With ``num_feature=56*56`` the output is
    ``(batch, 1, 28, 28)``, the size ``Discriminator`` accepts.

    Args:
        input_size: length of the latent vector ``z``.
        num_feature: size of the projected feature map; must be a perfect
            square with side >= 2.

    Raises:
        ValueError: if ``num_feature`` cannot be reshaped into a square map.
    """

    def __init__(self, input_size, num_feature):
        super().__init__()
        # Derive the square side instead of hard-coding 56, so a bad
        # num_feature fails here with a clear message, not inside forward().
        side = int(round(num_feature ** 0.5)) if num_feature > 0 else 0
        if side < 2 or side * side != num_feature:
            raise ValueError(
                f"num_feature must be a perfect square >= 4, got {num_feature}"
            )
        self.side = side
        self.fc = nn.Linear(input_size, num_feature)  # [input_size -> side*side]
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # [50, side, side]
            nn.BatchNorm2d(50),
            nn.ReLU(True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # [25, side, side]
            nn.BatchNorm2d(25),
            nn.ReLU(True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # [1, side/2, side/2]
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, self.side, self.side)
        x = self.br(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.conv3(x)
设置超参数以及初始化模型:
# Training hyper-parameters.
lr = 3e-4  # Adam learning rate, shared by G and D
batch_size = 8  # should match the DataLoader's batch_size
z_dimension = 100  # length of the latent noise vector fed to G
n_epoch = 20

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 56*56 is the size of G's projected feature map before its convolutions.
G = Generator(z_dimension, 56 * 56).to(device)
D = Discriminator().to(device)
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)
开始训练:
# Alternate one discriminator update and one generator update per batch.
# D ends in a sigmoid, so plain binary cross-entropy is the matching loss
# (this name was previously used without ever being defined).
criterion = nn.BCELoss()
# save_image() does not create missing directories.
os.makedirs("./images", exist_ok=True)

for epoch in range(n_epoch):
    for index, (images, _) in enumerate(train_loader):
        real_images = images.to(device)
        # Size targets/noise from the batch actually received, and create them
        # on `device` rather than via .cuda(), so CPU-only runs also work.
        cur_batch = real_images.size(0)
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # --- train Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ---
        real_out = D(real_images)
        D_loss_real = criterion(real_out, real_labels)
        z = torch.randn(cur_batch, z_dimension, device=device)
        # detach(): this step only updates D, so D_loss.backward() must not
        # propagate into G's parameters.
        fake_images = G(z).detach()
        fake_out = D(fake_images)
        D_loss_fake = criterion(fake_out, fake_labels)
        D_loss = (D_loss_real + D_loss_fake) / 2.0
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # --- train Generator: make D classify fresh fakes as real ---
        z = torch.randn(cur_batch, z_dimension, device=device)
        fake_images = G(z)
        fake_out = D(fake_images)
        G_loss = criterion(fake_out, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if (index + 1) % 100 == 0:
            print("[%d/%d] [%d/%d] G_loss: %.06f D_loss: %.06f"
                  % (epoch + 1, n_epoch, index + 1, len(train_loader),
                     G_loss.item(), D_loss.item()))

    # Save a batch of generated digits after every epoch; no gradients needed.
    with torch.no_grad():
        z = torch.randn(batch_size, z_dimension, device=device)
        imgs = G(z)
    # G ends in tanh (values in [-1, 1]) but save_image clamps to [0, 1];
    # invert the Normalize(0.5, 0.5) mapping so negative values are not lost.
    save_image(imgs * 0.5 + 0.5, './images/images_%d.png' % (epoch + 1))
这是第1个epoch的结果:
这是第10个epoch的结果:
这是第20个epoch的结果:
可以看到,生成的图片质量越来越好,即越来越像训练数据中的图片。