GAN生成手写数字MNIST

  1 # !/usr/bin/python
  2 # -*- coding: UTF-8 -*-
  3 # 参考自:https://blog.csdn.net/qq_39395755/article/details/109305055?spm=1001.2014.3001.5502
  4  
  5 import torch
  6 import torch.nn as nn
  7 import torch.nn.functional as func
  8 import torchvision
  9 import matplotlib.pylab as plt
 10 import numpy as np
 11  
 12  
 13  
 14 batch_size = 160
 15  
 16 # 将读取的图片转换为tensor 并标准化
 17 transform = torchvision.transforms.Compose([
 18     torchvision.transforms.ToTensor(),
 19     torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
 20 ])
 21 """
 22 ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1).具体地说,对每个通道而言,Normalize执行以下操作:
 23 image=(image-mean)/std
 24 其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。原来的0-1最小值0则变成(0-0.5)/0.5=-1,而最大值1则变成(1-0.5)/0.5=1.
 25 因为用到的mnist为灰度图 单通道 所以mean和std 只用了一个值
 26 数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度。
 27 因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。
 28 """
 29 
 30 dataset = torchvision.datasets.MNIST("./MNIST", train=True, transform=transform, download=True)# 如果MNIST数据集有的话,download可以设置为False
 31 data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size = batch_size, shuffle=True) # shuffle乱序
 32 device = torch.device("cuda")
 33  
 34 def denomalize(x):
 35     """还原被标准化后的图像"""
 36     out = (x+1) / 2
 37     out = out.view(32, 28, 28).unsqueeze(1) # 添加channel
 38     return out.clamp(0,1)
 39  
 40 def imshow(img, epoch):
 41     """打印生成器产生的图片"""
 42     # torchvision.utils.make_grid用来连接一组图, img为一个tensor(batch, channel, height, weight)
 43     # .detach()消除梯度
 44     im = torchvision.utils.make_grid(img, nrow=8).detach().numpy()
 45     # print(np.shape(im))
 46     plt.title("Epoch on %d" % (epoch+1))
 47     plt.imshow(im.transpose(1, 2, 0)) # 调整图形标签, plt的图片格式为(height, weight, channel)
 48     plt.savefig('./save_pic/%d.jpg' % (epoch+1))
 49     plt.show()
 50  
 51 # 判别器模型
 52 class DNet(nn.Module):
 53     def __init__(self):
 54         super(DNet, self).__init__()
 55  
 56         self.l1 = nn.Linear(28*28, 256)
 57         self.a = nn.ReLU()
 58         self.l2 = nn.Linear(256, 128)
 59         self.l3 = nn.Linear(128, 1)
 60         self.s = nn.Sigmoid()
 61  
 62  
 63     def forward(self, x):
 64         x = self.l1(x)
 65         x = self.a(x)
 66         x = self.l2(x)
 67         x = self.a(x)
 68         x = self.l3(x)
 69         x = self.s(x)
 70  
 71         return x
 72  
 73  
 74 # 生成器模型
 75 class GNet(nn.Module):
 76     def __init__(self):
 77         super(GNet, self).__init__()
 78  
 79         self.l1 = nn.Linear(10, 128)
 80         self.a = nn.ReLU()
 81         self.l2 = nn.Linear(128, 256)
 82         self.l3 = nn.Linear(256, 28*28)
 83         self.tanh = nn.Tanh()
 84  
 85     def forward(self, x):
 86         x = self.l1(x)
 87         x = self.a(x)
 88         x = self.l2(x)
 89         x = self.a(x)
 90         x = self.l3(x)
 91         x = self.tanh(x)
 92  
 93         return x
 94  
 95 # 构建模型并送入GPU
 96 D = DNet().to(device)
 97 G = GNet().to(device)
 98  
 99 print(D)
100 print(G)
101  
102 # 设置优化器
103 D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
104 G_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
105  
106  
107  
108  
109 for epoch in range(250):
110     cerrent = 0.0 # 正确识别
111     for step, data in enumerate(data_loader):
112         # 获取真实图集 并拉直
113         real_images = data[0].reshape(batch_size, -1).to(device)
114  
115         # 构造真假标签
116         real_labels = torch.ones(batch_size, 1).to(device)
117         fake_labels = torch.zeros(batch_size, 1).to(device)
118  
119         # 训练辨别器 分别将真图片和真标签喂入判别器、生成图和假标签喂入判别器
120         # 判别器的损失为真假训练的损失和
121         # print(real_images.size())
122         real_outputs = D(real_images)
123         real_loss = func.binary_cross_entropy(real_outputs, real_labels)
124  
125         z = torch.randn(batch_size, 10).to(device) # 用生成器产生fake图喂入判别器网络
126         fake_images = G(z)
127         d_fake_outputs = D(fake_images)
128         fake_loss = func.binary_cross_entropy(d_fake_outputs, fake_labels)
129  
130         d_loss = real_loss + fake_loss
131         G_optimizer.zero_grad()
132         D_optimizer.zero_grad()
133         d_loss.backward()
134         D_optimizer.step()
135  
136         # 训练生成器
137         z = torch.randn(batch_size, 10).to(device)
138         fake_images = G(z)
139         fake_outputs = D(fake_images)
140         g_loss = func.binary_cross_entropy(fake_outputs, real_labels) # 将fake图和真标签喂入判别器, 当g_loss越小生成越真实
141         G_optimizer.zero_grad()
142         D_optimizer.zero_grad()
143         g_loss.backward()
144         G_optimizer.step()
145  
146  
147         if step % 20 == 19:
148             print("epoch: " , epoch+1, "  step: ", step+1, "   d_loss: %.4f" % d_loss.mean().item(),
149                   "   g_loss: %.4f" % g_loss.mean().item(), "   d_acc:  %.4f" %  real_outputs.mean().item(),
150                   "   d(g)_acc: %.4f" % d_fake_outputs.mean().item())
151     # 每10个epoch 进行一次生成
152     if epoch % 10 == 9:
153         z = torch.randn(32, 10).to(device)
154         img = G(z)
155  
156         imshow(denomalize(img.to("cpu")), epoch)
157  
158  
159  
160 torch.save(D.state_dict(), "./D.pth")
161 torch.save(G.state_dict(), "./G.pth")

250个epoch后:

 

posted @ 2022-11-19 17:33  silvan_happy  阅读(73)  评论(0)  编辑  收藏  举报