GAN生成图像,数据集CIFAR-10
实验结果(以下为最后一个 epoch,即 epoch 索引 14 的部分训练日志):
Epoch index: 14, 15 epoches in total.
Batch index: 400, the batch size is 64.
Discriminator loss is: 0.9480361938476562, generator loss is: 1.1343715190887451
Discriminator tells real images real ability: 0.4963718354701996
Discriminator tells fake images real ability: 0.142733/0.372261
Epoch index: 14, 15 epoches in total.
Batch index: 500, the batch size is 64.
Discriminator loss is: 0.601941704750061, generator loss is: 2.3022451400756836
Discriminator tells real images real ability: 0.768054187297821
Discriminator tells fake images real ability: 0.249206/0.123098
Epoch index: 14, 15 epoches in total.
Batch index: 600, the batch size is 64.
Discriminator loss is: 0.9595736265182495, generator loss is: 3.147702217102051
Discriminator tells real images real ability: 0.8886741399765015
Discriminator tells fake images real ability: 0.519959/0.0610514
Epoch index: 14, 15 epoches in total.
Batch index: 700, the batch size is 64.
Discriminator loss is: 0.47996288537979126, generator loss is: 2.13641357421875
Discriminator tells real images real ability: 0.7488793730735779
Discriminator tells fake images real ability: 0.129179/0.145748
Process finished with exit code 0
完整代码(批量大小为 64):
# Reference: https://blog.csdn.net/frank_haha/article/details/119894541
# Trains a small DCGAN-style GAN on CIFAR-10, periodically saves sample
# images, then saves both networks and one final batch of generated images.
import os

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import save_image

latent_size = 64  # length of the noise vector fed to the generator
n_channel = 3  # image channels (RGB)
n_g_feature = 64  # base feature-map width of the generator

# Generator: noise (N, latent_size, 1, 1) -> image (N, 3, 32, 32).
# Spatial size grows 1 -> 4 -> 8 -> 16 -> 32 through the transposed convs.
gnet = nn.Sequential(
    nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),
    nn.BatchNorm2d(4 * n_g_feature),  # batch normalization
    nn.ReLU(),

    nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2 * n_g_feature),
    nn.ReLU(),

    nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(n_g_feature),
    nn.ReLU(),

    nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1),
    nn.Sigmoid(),  # outputs in [0, 1], the same range transforms.ToTensor() gives real images
)

n_d_feature = 64  # base feature-map width of the discriminator

# Discriminator: image (N, 3, 32, 32) -> one raw logit per image, (N, 1, 1, 1).
# Spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1. There is deliberately no final
# Sigmoid: BCEWithLogitsLoss below applies it internally.
dnet = nn.Sequential(
    nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),

    nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2 * n_d_feature),
    nn.LeakyReLU(0.2),

    nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(4 * n_d_feature),
    nn.LeakyReLU(0.2),

    nn.Conv2d(4 * n_d_feature, 1, kernel_size=4),
)


def weights_init(m):
    """Randomly initialize one module's parameters; intended for ``nn.Module.apply``.

    Conv2d / ConvTranspose2d weights get Xavier-normal init. BatchNorm2d scale
    is drawn from N(1.0, 0.02) and its shift is set to 0. Other modules are
    left unchanged.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)


gnet.apply(weights_init)
dnet.apply(weights_init)

# CIFAR-10 training split (50000 images). With batch_size=64 one epoch is
# ceil(50000 / 64) = 782 batches; the last batch holds only 16 images.
dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Binary cross-entropy on raw logits (sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()
# Adam for both networks with the usual GAN settings (lr=2e-4, beta1=0.5).
goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Noise reused for every snapshot so saved samples are comparable over time.
fixed_noises = torch.randn(dataloader.batch_size, latent_size, 1, 1)

# save_image does not create missing directories, so create them up front.
os.makedirs('./GAN_saved02', exist_ok=True)
os.makedirs('./test_GAN', exist_ok=True)

epoch_num = 15
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        real_images, _ = data  # class labels are not used by the GAN
        batch_size = real_images.size(0)  # 64, except for the final smaller batch

        # ---- Discriminator step: push D(real) toward 1 and D(G(z)) toward 0. ----
        labels = torch.ones(batch_size)
        preds = dnet(real_images)
        outputs = preds.reshape(-1)
        dloss_real = criterion(outputs, labels)
        dmean_real = outputs.sigmoid().mean()  # mean P(real) that D gives real images

        noises = torch.randn(batch_size, latent_size, 1, 1)
        fake_images = gnet(noises)
        labels = torch.zeros(batch_size)
        fake = fake_images.detach()  # keep this loss from backpropagating into gnet

        preds = dnet(fake)
        outputs = preds.view(-1)
        dloss_fake = criterion(outputs, labels)
        dmean_fake = outputs.sigmoid().mean()  # mean P(real) that D gives fakes (before its update)

        dloss = dloss_real + dloss_fake
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()

        # ---- Generator step: make the just-updated D label G(z) as real. ----
        labels = torch.ones(batch_size)
        preds = dnet(fake_images)
        outputs = preds.view(-1)
        gloss = criterion(outputs, labels)
        gmean_fake = outputs.sigmoid().mean()  # mean P(real) for fakes after D's update
        # gloss.backward() also accumulates gradients in dnet; those are
        # cleared by dnet.zero_grad() before the next discriminator update.
        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()

        if batch_idx % 350 == 1:
            # Sampling only: no_grad avoids building an unused autograd graph.
            with torch.no_grad():
                samples = gnet(fixed_noises)
            save_image(samples, f'./GAN_saved02/images_epoch{epoch:02d}_batch{batch_idx:03d}.png')

            # Formatting a 0-dim tensor in an f-string formats its Python value.
            print(f'Epoch index: {epoch}, {epoch_num} epoches in total.')
            print(f'Batch index: {batch_idx}, the batch size is {batch_size}.')
            print(f'Discriminator loss is: {dloss}, generator loss is: {gloss}')
            print(f'Discriminator tells real images real ability: {dmean_real}')
            print(f'Discriminator tells fake images real ability: {dmean_fake:g}/{gmean_fake:g}')

# Save the whole (pickled) modules. Reloading needs these class definitions
# importable; on PyTorch versions where torch.load defaults to
# weights_only=True, pass weights_only=False, e.g.
#     gnet = torch.load(gnet_save_path, weights_only=False)
gnet_save_path = 'gnet.pt'
torch.save(gnet, gnet_save_path)

dnet_save_path = 'dnet.pt'
torch.save(dnet, dnet_save_path)

# Generate a fresh batch from the trained generator and save it. Use the
# configured batch size, not the leftover size of the last (partial) batch.
with torch.no_grad():
    noises = torch.randn(dataloader.batch_size, latent_size, 1, 1)
    fake_images = gnet(noises)
save_image(fake_images, './test_GAN/1.png')