DCGAN

# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import numpy as np
import torch.nn.init as init
import os
import test
from GAN_model import Generator,Discriminator

print("data loading ...")

# Training hyper-parameters (both networks are optimised with Adam below).
G_LR = 0.0002    # generator learning rate
D_LR = 0.0002    # discriminator learning rate
BATCHSIZE = 50   # images per mini-batch
EPOCHES = 3000   # passes over the training set

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pre-built training data (produced by data_loader.py or data_loader_sketch.py).
data_pt = "data51200.pt"
# Directory receiving one generator/discriminator checkpoint per epoch;
# its files are deleted at the start of every run.
train_para_save_path = "./pkl/"
# Per-batch "d_loss g_loss" lines are appended to this file.
loss_save_file = 'loss.txt'

def init_ws_bs(m, std=0.02):
    """Initialise ``m`` and all of its sub-modules in place, DCGAN-style.

    Conv2d / ConvTranspose2d weights are drawn from N(0, std) and their
    biases (when the layer has one) are zeroed. BatchNorm2d scales are drawn
    from N(1, std) and their shifts are zeroed. This matches the DCGAN paper
    and the PyTorch DCGAN tutorial.

    The function walks ``m.modules()``, which yields ``m`` itself followed by
    every nested sub-module. It therefore works both when called on a whole
    network (``init_ws_bs(g)``) and when used as a per-layer callback
    (``g.apply(init_ws_bs)``).

    Args:
        m: network or single layer to initialise.
        std: standard deviation of the normal draws (DCGAN uses 0.02).
    """
    for layer in m.modules():
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            init.normal_(layer.weight, mean=0.0, std=std)
            # Layers built with bias=False have ``bias is None``.
            if layer.bias is not None:
                init.zeros_(layer.bias)
        elif isinstance(layer, nn.BatchNorm2d):
            # affine=False batch norms carry no learnable scale/shift.
            if layer.weight is not None:
                init.normal_(layer.weight, mean=1.0, std=std)
            if layer.bias is not None:
                init.zeros_(layer.bias)

# Instantiate both networks directly on the training device.
g = Generator().to(device)
d = Discriminator().to(device)

# Custom weight initialisation, one network at a time.
init_ws_bs(g)
init_ws_bs(d)

# To resume from saved checkpoints instead of training from scratch, load the
# pickled modules written at the end of each epoch, e.g.:
#     para_path = "./pkl/"
#     g = torch.load(para_path + "29g.pkl")
#     d = torch.load(para_path + "29d.pkl")

# Adam with beta1 = 0.5 for both players.
g_optimizer = torch.optim.Adam(g.parameters(), lr=G_LR, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(d.parameters(), lr=D_LR, betas=(0.5, 0.999))

# Both players are trained with binary cross-entropy on the discriminator output.
g_loss_func = nn.BCELoss()
d_loss_func = nn.BCELoss()

# Fixed targets, one per image in a batch: 1 means "real", 0 means "fake".
label_real = torch.ones(BATCHSIZE, device=device)
label_fake = torch.zeros(BATCHSIZE, device=device)

# Start each run from a clean slate: drop the previous loss log.
if os.path.exists(loss_save_file):
    os.remove(loss_save_file)

# Load the pre-built training images. Previously a missing file fell through
# to a NameError on `real_img`; now the script stops with a clear message.
real_img = torch.load(data_pt) if os.path.exists(data_pt) else None
if real_img is None:
    print("fail to load data")
    raise SystemExit("cannot load training data from %s" % data_pt)
print("load data successfully")

# Make sure the checkpoint directory exists and holds no stale checkpoints.
if not os.path.exists(train_para_save_path):
    os.makedirs(train_para_save_path)
for name in os.listdir(train_para_save_path):
    stale_path = os.path.join(train_para_save_path, name)
    if os.path.isfile(stale_path):  # leave sub-directories alone instead of crashing
        os.remove(stale_path)


print("start training")
num_imgs = len(real_img)
# Only full batches are used because label_real / label_fake hold exactly
# BATCHSIZE targets. Any incomplete tail differs each epoch thanks to the shuffle.
num_batches = num_imgs // BATCHSIZE
for epoch in range(EPOCHES):
    # Shuffle an index permutation rather than the data itself:
    # np.random.shuffle swaps rows of a torch tensor through views, which
    # silently duplicates some images and loses others.
    order = torch.randperm(num_imgs).tolist()
    loss_epoch = []
    for batch_num in range(num_batches):
        batch_idx = order[batch_num * BATCHSIZE:(batch_num + 1) * BATCHSIZE]
        # torch.stack avoids the slow list-of-ndarray -> tensor conversion;
        # .float() keeps the float32 dtype that torch.Tensor(...) produced.
        batch_real = torch.stack([real_img[j] for j in batch_idx]).float().to(device)

        #### Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
        d_optimizer.zero_grad()
        pre_real = d(batch_real).view(-1)  # (B, 1, 1, 1) -> (B,), safe even for B == 1
        d_real_loss = d_loss_func(pre_real, label_real)
        d_real_loss.backward()

        noise = torch.randn(BATCHSIZE, 100, 1, 1, device=device)
        img_fake = g(noise)
        # detach(): the discriminator step must not backpropagate into G.
        pre_fake = d(img_fake.detach()).view(-1)
        d_fake_loss = d_loss_func(pre_fake, label_fake)
        d_fake_loss.backward()
        d_optimizer.step()

        #### Generator step: push D(G(z)) -> 1 using fresh noise.
        g_optimizer.zero_grad()
        noise = torch.randn(BATCHSIZE, 100, 1, 1, device=device)
        pre_fake = d(g(noise)).view(-1)
        g_loss = g_loss_func(pre_fake, label_real)
        g_loss.backward()
        g_optimizer.step()

        d_loss_value = (d_real_loss + d_fake_loss).item()
        g_loss_value = g_loss.item()
        print("epoch%d batch%d:" % (epoch, batch_num), d_loss_value, g_loss_value)
        loss_epoch.append((d_loss_value, g_loss_value))

    # After each epoch: checkpoint both networks, append losses, render samples.
    g_path = train_para_save_path + str(epoch) + "g.pkl"
    torch.save(g, g_path)
    torch.save(d, train_para_save_path + str(epoch) + "d.pkl")
    with open(loss_save_file, 'a') as f:
        f.writelines("%s %s\n" % pair for pair in loss_epoch)

    test.draw(g_path, str(epoch))
print("finish the train")

  GAN_model.py

import torch.nn as nn
class Generator(nn.Module):
    """DCGAN generator: latent noise (N, 100, 1, 1) -> RGB image (N, 3, 96, 96).

    Four ConvTranspose2d -> BatchNorm2d -> ReLU stages are followed by a final
    ConvTranspose2d -> Tanh, so pixel values lie in [-1, 1]. Spatial size per
    stage, from out = stride * (in - 1) + kernel - 2 * padding:
    1 -> 4 -> 8 -> 16 -> 32 -> 96.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def up_block(c_in, c_out, kernel, stride, pad):
            # Bias-free transposed conv, since the following BatchNorm re-centres anyway.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        base = 64
        self.deconv1 = up_block(100, base * 8, 4, 1, 0)       # 1  -> 4
        self.deconv2 = up_block(base * 8, base * 4, 4, 2, 1)  # 4  -> 8
        self.deconv3 = up_block(base * 4, base * 2, 4, 2, 1)  # 8  -> 16
        self.deconv4 = up_block(base * 2, base, 4, 2, 1)      # 16 -> 32
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(base, 3, 5, 3, 1, bias=False),  # 32 -> 96
            nn.Tanh(),
        )

    def forward(self, x):
        stages = (self.deconv1, self.deconv2, self.deconv3,
                  self.deconv4, self.deconv5)
        for stage in stages:
            x = stage(x)
        return x
class Discriminator(nn.Module):
    """DCGAN discriminator: RGB image (N, 3, 96, 96) -> probability (N, 1, 1, 1).

    Four Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stages downsample the image.
    A 4x4 valid convolution followed by a Sigmoid then yields one score per image.
    Spatial size per stage, from out = (in + 2 * padding - kernel) // stride + 1:
    96 -> 32 -> 16 -> 8 -> 4 -> 1.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def down_block(c_in, c_out, kernel, stride, pad):
            # Bias-free conv, since the following BatchNorm re-centres anyway.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        width = 64
        self.conv1 = down_block(3, width, 5, 3, 1)              # 96 -> 32
        self.conv2 = down_block(width, width * 2, 4, 2, 1)      # 32 -> 16
        self.conv3 = down_block(width * 2, width * 4, 4, 2, 1)  # 16 -> 8
        self.conv4 = down_block(width * 4, width * 8, 4, 2, 1)  # 8  -> 4
        self.output = nn.Sequential(
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),       # 4  -> 1
            nn.Sigmoid(),
        )

    def forward(self, x):
        layers = (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.output)
        for layer in layers:
            x = layer(x)
        return x

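  GAN的精髓在于对抗。生成损失和判别损失的反向传播方式是一样的，区别只在于各自更新哪些参数：生成损失只更新生成器的参数，判别损失只更新判别器的参数。这由传给各自优化器的参数决定：g_optimizer 只接收 g.parameters()，d_optimizer 只接收 d.parameters()。此外，训练判别器时对假图片使用了 img_fake.detach()，因此判别损失的梯度不会传回生成器。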

  生成器的训练目标只有一个,让生成的假的图片更像真的:g_loss=g_loss_func(pre_fake,label_real)

  而判别器的目标有两个。一是把真图片判为真：d_real_loss=d_loss_func(pre_real,label_real)

  二是把假图片判为假：d_fake_loss=d_loss_func(pre_fake,label_fake)

  

posted on 2019-12-06 15:44  江南烟雨尘  阅读(368)  评论(0)  编辑  收藏  举报

导航