Annotated DCGAN Code
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

# Argument parsing
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# Check whether GPU acceleration is available
cuda = True if torch.cuda.is_available() else False


# Used with nn.Module.apply to initialize the conv and batch-norm parameters
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Initial feature-map size: the image height/width divided by 4
        self.init_size = opt.img_size // 4

        # A fully connected layer first maps the latent vector to a 128-channel feature map
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        # Convolutional part: batch norm, then upsample, then convolve
        # LeakyReLU activations, with a Tanh at the output so pixel values match the [-1, 1] data normalization
        # Two x2 upsamplings undo the initial /4, restoring the original image size
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    # Forward pass, with a view reshape in between
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Each block changes the channel count with a conv, then applies LeakyReLU and dropout against overfitting
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        # Following the paper, the first block omits batch norm
        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of the downsampled image
        # Note operator precedence: ** binds tighter than //; each of the four convs has stride 2,
        # so the spatial size is halved four times
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    # Forward pass, with a view reshape in between
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        # Labels for the real batch are all 1
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        # Labels for the fake batch are all 0
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input: the real images
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise of shape (batch_size, latent_dim) as the generator's input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of fake images
        gen_imgs = generator(z)

        # Loss measures the generator's ability to fool the discriminator:
        # the generator wants its fakes labeled all real (all 1s), so its outputs
        # from the discriminator are compared against the all-ones target
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure the discriminator's ability to classify real from generated samples
        # Can the discriminator recognize the real images?
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # Can the discriminator recognize the fake images? (detach so no gradients flow into the generator)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # Average of the two losses
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
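For intuition on the view/upsample arithmetic commented above: with the default img_size=32, the generator starts from init_size = 32 // 4 = 8 and is upsampled twice (x2 each time), ending back at 32, while the discriminator's four stride-2 convolutions halve the spatial size four times, so ds_size = 32 // 2 ** 4 = 2 and the final linear layer sees 128 * 2 * 2 = 512 features. A minimal shape check, assuming the classes above and the default hyperparameters, might look like this:

# Hypothetical sanity check; uses fresh CPU instances so it works with or without CUDA.
g = Generator()
d = Discriminator()

z = torch.randn(4, opt.latent_dim)   # a batch of 4 latent vectors
fake_batch = g(z)                    # l1 -> view to (4, 128, 8, 8) -> two x2 upsamples
print(fake_batch.shape)              # expected: torch.Size([4, 1, 32, 32])

validity = d(fake_batch)             # 32 -> 16 -> 8 -> 4 -> 2 via four stride-2 convs
print(validity.shape)                # expected: torch.Size([4, 1]); Sigmoid keeps values in (0, 1)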
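In case it helps to see what adversarial_loss actually computes here: BCELoss on a Sigmoid output is -[y*log(p) + (1-y)*log(1-p)], averaged over the batch. A tiny numeric check (the values are purely illustrative):

p = torch.tensor([[0.9], [0.2]])           # discriminator outputs for two samples
y = torch.tensor([[1.0], [1.0]])           # "valid" targets, as in the generator step
manual = -(y * p.log() + (1 - y) * (1 - p).log()).mean()
print(manual, torch.nn.BCELoss()(p, y))    # both ≈ 0.8574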
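Note that the script only saves sample grids, not model weights. If you want to reuse the trained generator afterwards, one possible sketch (the filename generator.pth is just an illustration, not something the original script produces):

# After training: persist the generator's parameters (filename is illustrative).
torch.save(generator.state_dict(), "generator.pth")

# Later, in a fresh process that defines the same Generator class and opt defaults:
generator = Generator()
generator.load_state_dict(torch.load("generator.pth", map_location="cpu"))
generator.eval()                      # switch BatchNorm to its running statistics

with torch.no_grad():
    z = torch.randn(25, opt.latent_dim)
    samples = generator(z)            # values in [-1, 1] because of the Tanh output
    save_image(samples, "samples.png", nrow=5, normalize=True)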