使用 GAN 生成图像,数据集为 CIFAR-10

实验结果(最后一个 epoch,即 epoch 14 的部分训练日志):

Epoch index: 14, 15 epoches in total.

Batch index: 400, the batch size is 64.

Discriminator loss is: 0.9480361938476562, generator loss is: 1.1343715190887451

 Discriminator tells real images real ability: 0.4963718354701996

 Discriminator tells fake images real ability: 0.142733/0.372261

Epoch index: 14, 15 epoches in total.

Batch index: 500, the batch size is 64.

Discriminator loss is: 0.601941704750061, generator loss is: 2.3022451400756836

 Discriminator tells real images real ability: 0.768054187297821

 Discriminator tells fake images real ability: 0.249206/0.123098

Epoch index: 14, 15 epoches in total.

Batch index: 600, the batch size is 64.

Discriminator loss is: 0.9595736265182495, generator loss is: 3.147702217102051

 Discriminator tells real images real ability: 0.8886741399765015

 Discriminator tells fake images real ability: 0.519959/0.0610514

Epoch index: 14, 15 epoches in total.

Batch index: 700, the batch size is 64.

Discriminator loss is: 0.47996288537979126, generator loss is: 2.13641357421875

 Discriminator tells real images real ability: 0.7488793730735779

 Discriminator tells fake images real ability: 0.129179/0.145748

 

Process finished with exit code 0

代码(批量大小为 64):

 

# Adapted from: https://blog.csdn.net/frank_haha/article/details/119894541
import os

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
  9 
 10 latent_size = 64  # 隐藏层大小
 11 n_channel = 3  # 通道数
 12 n_g_feature = 64
 13 # 生成模型
 14 gnet = nn.Sequential(
 15     nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),  # 卷积
 16     nn.BatchNorm2d(4 * n_g_feature),  # 批量正则化
 17     nn.ReLU(),  # 激活函数
 18 
 19     nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
 20     nn.BatchNorm2d(2 * n_g_feature),
 21     nn.ReLU(),
 22 
 23     nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
 24     nn.BatchNorm2d(n_g_feature),
 25     nn.ReLU(),
 26 
 27     nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1),
 28     nn.Sigmoid()
 29 )
 30 
 31 # 判别模型
 32 n_d_feature = 64
 33 dnet = nn.Sequential(
 34     nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1),
 35     nn.LeakyReLU(0.2),
 36 
 37     nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
 38     nn.BatchNorm2d(2 * n_d_feature),
 39     nn.LeakyReLU(0.2),
 40 
 41     nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
 42     nn.BatchNorm2d(4 * n_d_feature),
 43     nn.LeakyReLU(0.2),
 44 
 45     nn.Conv2d(4 * n_d_feature, 1, kernel_size=4)
 46 )
 47 
 48 
 49 # 权重随机初始化
 50 def weights_init(m):
 51     if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
 52         init.xavier_normal_(m.weight)
 53     elif type(m) == nn.BatchNorm2d:
 54         init.normal_(m.weight, 1.0, 0.02)
 55         init.constant_(m.bias, 0)
 56 
 57 
 58 gnet.apply(weights_init)
 59 dnet.apply(weights_init)
 60 
 61 # 读取数据集并转换为dataloader类型
 62 dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())  # 大小为50000
 63 dataloader = DataLoader(dataset, batch_size=64, shuffle=True)  # 总数据量为50000,批量大小64,因此轮训整个数据集需要782次
 64 
 65 # 定义Loss函数
 66 criterion = nn.BCEWithLogitsLoss()
 67 # 使用adam优化器,根据损失函数计算loss值
 68 goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
 69 doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
 70 
 71 # batch_size = 128
 72 # 随机初始化,大小为批量大小*隐藏层大小
 73 fixed_noises = torch.randn(dataloader.batch_size, latent_size, 1, 1)
 74 
 75 # 进行15个epoch
 76 epoch_num = 15
 77 for epoch in range(epoch_num):
 78     for batch_idx, data in enumerate(dataloader):  # 每个epoch需要计算782次,每次为一个批量
 79         real_images, _ = data  # 取真实图像信息,real_images为(64,3,32,32)
 80         batch_size = real_images.size(0)  # 取real_images的第0维大小,即批量大小
 81 
 82         labels = torch.ones(batch_size)  # labels为1
 83         preds = dnet(real_images)  # 预测值
 84         outputs = preds.reshape(-1)
 85         dloss_real = criterion(outputs, labels)  # 计算损失
 86         dmean_real = outputs.sigmoid().mean()
 87 
 88         noises = torch.randn(batch_size, latent_size, 1, 1)
 89         fake_images = gnet(noises)  # 生成图像
 90         labels = torch.zeros(batch_size)
 91         fake = fake_images.detach()
 92 
 93         preds = dnet(fake)
 94         outputs = preds.view(-1)
 95         dloss_fake = criterion(outputs, labels)
 96         dmean_fake = outputs.sigmoid().mean()
 97 
 98         dloss = dloss_real + dloss_fake
 99         dnet.zero_grad()
100         dloss.backward()  # 反向梯度
101         doptimizer.step()
102 
103         labels = torch.ones(batch_size)  # 判别生成图像
104         preds = dnet(fake_images)
105         outputs = preds.view(-1)
106         gloss = criterion(outputs, labels)
107         gmean_fake = outputs.sigmoid().mean()
108         gnet.zero_grad()
109         gloss.backward()
110         goptimizer.step()
111 
112         if batch_idx % 350 == 1:  # 生成图像并保存打印
113             fake = gnet(fixed_noises)
114             save_image(fake, f'./GAN_saved02/images_epoch{epoch:02d}_batch{batch_idx:03d}.png')
115 
116             print(f'Epoch index: {epoch}, {epoch_num} epoches in total.')
117             print(f'Batch index: {batch_idx}, the batch size is {batch_size}.')
118             print(f'Discriminator loss is: {dloss}, generator loss is: {gloss}', '\n',
119                   f'Discriminator tells real images real ability: {dmean_real}', '\n',
120                   f'Discriminator tells fake images real ability: {dmean_fake:g}/{gmean_fake:g}')
121 
# Save both networks as whole pickled modules (architecture + weights).
# To reload: net = torch.load(path); on PyTorch >= 2.6 loading a full module needs
# torch.load(path, weights_only=False). Call net.eval() before inference.
gnet_save_path = 'gnet.pt'
torch.save(gnet, gnet_save_path)

dnet_save_path = 'dnet.pt'
torch.save(dnet, dnet_save_path)

# Smoke test: sample a fresh batch from the trained generator and save it as one image grid.
os.makedirs('./test_GAN', exist_ok=True)  # save_image() does not create directories
# Use a full loader-sized batch rather than the loop's leftover batch_size (16 for the
# final batch of an epoch).
n_test_images = dataloader.batch_size
with torch.no_grad():
    noises = torch.randn(n_test_images, latent_size, 1, 1)
    fake_images = gnet(noises)
# Save the images generated just above, not the `fake` tensor left over from training.
save_image(fake_images, './test_GAN/1.png')

 

posted @ 2022-11-22 19:37  silvan_happy  阅读(631)  评论(0)  编辑  收藏  举报