DCGAN
# -*- coding: UTF-8 -*-
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init

import test
from GAN_model import Generator, Discriminator

print("data loading ...")

G_LR = 0.0002
D_LR = 0.0002
BATCHSIZE = 50
EPOCHES = 3000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_pt = "data51200.pt"          # path of the training data (produced by data_loader.py or data_loader_sketch.py)
train_para_save_path = "./pkl/"   # directory for the per-epoch checkpoints
loss_save_file = 'loss.txt'


def init_ws_bs(m):
    # DCGAN-style initialization, applied to every ConvTranspose2d via model.apply()
    if isinstance(m, nn.ConvTranspose2d):
        init.normal_(m.weight.data, std=0.2)
        if m.bias is not None:
            init.normal_(m.bias.data, std=0.2)


g = Generator().to(device)
d = Discriminator().to(device)
g.apply(init_ws_bs)
d.apply(init_ws_bs)

### load a previously trained model instead of starting from scratch
# para_path = "./pkl/"
# para_file = "29g.pkl"
# g = torch.load(para_path + para_file)
# d = torch.load(para_path + para_file)

g_optimizer = torch.optim.Adam(g.parameters(), betas=(0.5, 0.999), lr=G_LR)
d_optimizer = torch.optim.Adam(d.parameters(), betas=(0.5, 0.999), lr=D_LR)
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()

label_real = torch.ones(BATCHSIZE).to(device)
label_fake = torch.zeros(BATCHSIZE).to(device)

if os.path.exists(loss_save_file):
    os.remove(loss_save_file)

if os.path.exists(data_pt):
    real_img = torch.load(data_pt)
    if real_img is not None:
        print("load data successfully")
    else:
        print("fail to load data")

if not os.path.exists(train_para_save_path):
    os.makedirs(train_para_save_path)
for file in os.listdir(train_para_save_path):
    os.remove(train_para_save_path + file)

print("start training")
batch_imgs = []
for epoch in range(EPOCHES):
    np.random.shuffle(real_img)
    loss_epoch = []
    for i in range(len(real_img)):
        batch_imgs.append(real_img[i].numpy())
        if (i + 1) % BATCHSIZE == 0:
            batch_real = torch.Tensor(batch_imgs).to(device)
            batch_imgs.clear()

            #### minimize the discriminator loss
            d_optimizer.zero_grad()
            pre_real = d(batch_real).squeeze()
            d_real_loss = d_loss_func(pre_real, label_real)
            d_real_loss.backward()

            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            pre_fake = d(img_fake.detach()).squeeze()   # detach: no gradient flows into the generator here
            d_fake_loss = d_loss_func(pre_fake, label_fake)
            d_fake_loss.backward()
            d_optimizer.step()

            #### minimize the generator loss
            g_optimizer.zero_grad()
            batch_fake = torch.randn(BATCHSIZE, 100, 1, 1).to(device)
            img_fake = g(batch_fake)
            pre_fake = d(img_fake).squeeze()
            g_loss = g_loss_func(pre_fake, label_real)
            g_loss.backward()
            g_optimizer.step()

            batch_num = i // BATCHSIZE
            print("epoch%d batch%d:" % (epoch, batch_num),
                  (d_real_loss + d_fake_loss).detach().cpu().numpy(),
                  g_loss.detach().cpu().numpy())
            loss_epoch.append([(d_real_loss + d_fake_loss).detach().cpu().numpy(),
                               g_loss.detach().cpu().numpy()])

    ### after each epoch, save both models, append the losses and draw sample images
    torch.save(g, train_para_save_path + str(epoch) + "g.pkl")
    torch.save(d, train_para_save_path + str(epoch) + "d.pkl")
    with open(loss_save_file, 'a+') as f:
        for d_loss_epoch, g_loss_epoch in loss_epoch:
            f.write(str(d_loss_epoch) + ' ' + str(g_loss_epoch) + '\n')
    test.draw(train_para_save_path + str(epoch) + "g.pkl", str(epoch))

print("finish the train")
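The script pickles the whole generator every epoch and calls test.draw (test.py is not shown here) to render samples. As a rough illustration of how those checkpoints could be reused, here is a minimal sketch, not part of the original repo: the checkpoint path and epoch number are placeholders, torchvision is assumed to be installed, and GAN_model.py must be importable so the pickled Generator class can be resolved.

# sample_from_checkpoint.py -- illustrative sketch only.
import torch
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = "./pkl/29g.pkl"          # hypothetical epoch-29 generator checkpoint
# torch.save(g, ...) stored the full module; on newer PyTorch you may need
# torch.load(..., weights_only=False) to unpickle it.
g = torch.load(checkpoint, map_location=device)
g.eval()

with torch.no_grad():
    noise = torch.randn(64, 100, 1, 1, device=device)   # same latent shape as in training
    fake = g(noise)                                      # (64, 3, 96, 96), Tanh output in [-1, 1]
    save_image(fake, "samples_epoch29.png", nrow=8, normalize=True)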
GAN_model.py
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # input: (batchsize, 100, 1, 1)
        # ConvTranspose2d output size: stride * (input_w - 1) + kernel - 2 * padding
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=100,
                out_channels=64 * 8,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(inplace=True),
        )  # (batchsize, 512, 4, 4)
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 8,
                out_channels=64 * 4,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(inplace=True),
        )  # (batchsize, 256, 8, 8)
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 4,
                out_channels=64 * 2,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(inplace=True),
        )  # (batchsize, 128, 16, 16)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=64 * 2,
                out_channels=64 * 1,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )  # (batchsize, 64, 32, 32)
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh(),
        )  # (batchsize, 3, 96, 96)

    def forward(self, x):
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        x = self.deconv5(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input: (batchsize, 3, 96, 96)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=5,
                padding=1,
                stride=3,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 64, 32, 32)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 128, 16, 16)
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 256, 8, 8)
        self.conv4 = nn.Sequential(
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(.2, inplace=True),
        )  # (batchsize, 512, 4, 4)
        self.output = nn.Sequential(
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )  # (batchsize, 1, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.output(x)
        return x
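A quick sanity check, not part of the original repo, that confirms the shapes annotated in the comments above: random noise of shape (N, 100, 1, 1) should come out of the Generator as 3x96x96 images, and the Discriminator should return one probability per image.

# shape_check.py -- illustrative sketch only.
import torch
from GAN_model import Generator, Discriminator

g = Generator()
d = Discriminator()

noise = torch.randn(4, 100, 1, 1)      # (batch, latent, 1, 1)
fake = g(noise)
print(fake.shape)                      # torch.Size([4, 3, 96, 96])

score = d(fake)
print(score.shape)                     # torch.Size([4, 1, 1, 1]); .squeeze() gives one probability per image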
The essence of a GAN is the adversarial game. The generator loss and the discriminator loss are back-propagated in exactly the same way; the difference is that the generator loss updates only the generator's parameters and the discriminator loss updates only the discriminator's parameters, because each optimizer is constructed with only that network's parameters.
The generator has a single objective: make the generated fake images look real: g_loss = g_loss_func(pre_fake, label_real)
The discriminator has two objectives: make real images score as real: d_real_loss = d_loss_func(pre_real, label_real)
and make fake images score as fake: d_fake_loss = d_loss_func(pre_fake, label_fake)
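A condensed sketch of one adversarial step, using the same names as the training script above (toy random tensors stand in for a real batch; this is an illustration, not a replacement for the script):

# one_adversarial_step.py -- illustrative sketch only.
import torch
import torch.nn as nn
from GAN_model import Generator, Discriminator

g, d = Generator(), Discriminator()
# Each optimizer only sees one network's parameters, so d_optimizer.step()
# can never touch the generator and vice versa.
g_optimizer = torch.optim.Adam(g.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(d.parameters(), lr=2e-4, betas=(0.5, 0.999))
loss_func = nn.BCELoss()

batch_real = torch.randn(8, 3, 96, 96)   # stand-in for a real batch
label_real = torch.ones(8)
label_fake = torch.zeros(8)

# Discriminator step: real -> 1, fake -> 0.  detach() blocks the gradient
# from flowing back into the generator during this step.
d_optimizer.zero_grad()
img_fake = g(torch.randn(8, 100, 1, 1))
d_loss = loss_func(d(batch_real).squeeze(), label_real) \
       + loss_func(d(img_fake.detach()).squeeze(), label_fake)
d_loss.backward()
d_optimizer.step()

# Generator step: the same kind of backward pass, but the fake images are
# scored against label_real, and only g's parameters are updated.
g_optimizer.zero_grad()
img_fake = g(torch.randn(8, 100, 1, 1))
g_loss = loss_func(d(img_fake).squeeze(), label_real)
g_loss.backward()
g_optimizer.step()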