Annotated DCGAN Code

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors now work the same way

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# Use GPU acceleration if available
cuda = True if torch.cuda.is_available() else False


# Initialize conv and batch-norm parameters; meant to be used with nn.Module.apply
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

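# model.apply(fn) calls fn recursively on every submodule, so the class-name
# checks above reach each Conv2d and BatchNorm2d layer in the network.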

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Start from the image size downsampled by 4 in height and width;
        # the two 2x upsamplings below restore it
        self.init_size = opt.img_size // 4

        # First, a fully connected layer maps the latent vector to the
        # dimensions of a 128-channel feature map
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        # Convolutional part: normalize, upsample, then convolve.
        # LeakyReLU and Tanh are used to satisfy the paper's goal of mapping
        # images back into the latent space.
        # Two 2x upsamplings undo the initial //4, so the output is img_size.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),  # 0.8 is passed positionally as eps, as in the original code
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    # Forward pass; view() reshapes the flat features into a 4D feature map
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

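# Shape walkthrough with the defaults (img_size=32, latent_dim=100, init_size=8):
# z: (N, 100) -> l1 -> (N, 8192) -> view -> (N, 128, 8, 8)
#   -> upsample + conv -> (N, 128, 16, 16) -> upsample + conv -> (N, 64, 32, 32)
#   -> conv + tanh -> (N, channels, 32, 32)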

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Each block: a stride-2 conv changes the channel count, then a
        # LeakyReLU activation and dropout to curb overfitting
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        # Per the paper, the first block omits batch normalization, which
        # improves results
        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of the downsampled image.
        # Mind operator precedence: ** binds tighter than //. Each of the four
        # stride-2 convs halves the spatial size, hence img_size // 2**4.
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    # Forward pass; view() flattens the feature map before the linear layer
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

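# Worked example with the default img_size=32: ds_size = 32 // 2**4 = 2
# (32 -> 16 -> 8 -> 4 -> 2), so the linear layer sees 128 * 2 * 2 = 512 features.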

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
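# Normalize([0.5], [0.5]) maps pixels from [0, 1] to [-1, 1], matching the
# range of the generator's Tanh output.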

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        # labels for a real batch are all 1
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        # labels for a fake batch are all 0
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input
        # the real images
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        # batch_size x latent_dim noise, fed to the generator as its input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        # produce the fake images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        # Generator loss: the generator wants the discriminator to score its
        # fakes as all real (all 1s); the loss measures how far the
        # discriminator's outputs fall from 1
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        # can the discriminator recognize the real images?
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # can it recognize the fakes? detach() keeps this pass from
        # backpropagating into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # average of the two losses
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
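
Once training finishes, sampling new digits is just a forward pass through the generator. A minimal sketch, assuming it is appended to the end of the script above so that generator, Tensor, and opt are still in scope (the output file name is illustrative):

# Sample a 5x5 grid of digits from the trained generator
generator.eval()  # put batch-norm layers into inference mode
with torch.no_grad():  # no gradients are needed for sampling
    z = Tensor(np.random.normal(0, 1, (25, opt.latent_dim)))
    samples = generator(z)
save_image(samples, "images/final_samples.png", nrow=5, normalize=True)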

 
