
Conditional GAN Code Implementation (PyTorch)

Paper:
https://arxiv.org/abs/1411.1784

1. Background

  1. A vanilla GAN can generate images, but it offers no control over which class is generated. For example, when generating handwritten digits, both GAN and DCGAN can produce all ten digits 0-9, but the user cannot specify which digit to generate;
  2. GAN and DCGAN suffer from mode collapse, where the generator keeps producing only a few similar outputs instead of covering the full data distribution.

2. Main Idea

  • A GAN consists of two networks, a generator and a discriminator. Its objective function is

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
Compared with a vanilla GAN, the only difference in CGAN is that the label y is added as an extra input to both networks during training (see the loss sketch after this list). The CGAN objective is
$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x|y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z|y)))]$$

  • Architecture diagram: (figure omitted; see the conditional adversarial net figure in the paper)
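
To make the objectives concrete: in practice, both expectations are optimized by alternating two binary cross-entropy losses, one per network. Below is a minimal, illustrative sketch (not the full training code from section 4; `disc`, `real_img`, `fake_img`, and `label` are placeholder names):

import torch
from torch import nn

bce = nn.BCELoss()

def d_step(disc, real_img, fake_img, label):
    # Discriminator step: maximize log D(x|y) + log(1 - D(G(z|y))),
    # i.e. minimize BCE against target 1 for real and target 0 for fake.
    real_out = disc(real_img, label)
    fake_out = disc(fake_img.detach(), label)  # detach: do not backprop into G here
    return bce(real_out, torch.ones_like(real_out)) + \
           bce(fake_out, torch.zeros_like(fake_out))

def g_step(disc, fake_img, label):
    # Generator step: rather than minimizing log(1 - D(G(z|y))),
    # maximize log D(G(z|y)), i.e. BCE against target 1 (non-saturating loss).
    fake_out = disc(fake_img, label)
    return bce(fake_out, torch.ones_like(fake_out))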

3. Implementation

  1. Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # two 128 x 7 x 7 feature maps; concatenated they give 256 x 7 x 7
        self.linear1 = nn.Sequential(
            nn.Linear(100, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.linear2 = nn.Sequential(
            nn.Linear(10, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )

        self.model = nn.Sequential(
            # 128 x 7 x 7
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 1 x 28 x 28
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, c):
        x = self.linear1(x)
        x = x.view(-1, 128, 7, 7)
        c = self.linear2(c)
        c = c.view(-1, 128, 7, 7)
        # concatenate along the channel dimension: 256 x 7 x 7
        x = torch.cat([x, c], dim=1)
        return self.model(x)

The generator takes two inputs: a random noise vector and a label. The noise and the label are each mapped to a vector of length 128x7x7 and reshaped, then concatenated into a 256x7x7 tensor, which passes through three transposed convolutions to produce a 1x28x28 image (the same size as the MNIST images).
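
A quick shape check confirms the data flow (a hypothetical smoke test; it assumes `import torch` and the Generator class above):

gen = Generator()
noise = torch.randn(4, 100)                          # batch of 4 noise vectors
labels = torch.eye(10)[torch.randint(0, 10, (4,))]   # 4 one-hot labels
print(gen(noise, labels).shape)                      # expected: torch.Size([4, 1, 28, 28])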

  2. Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input: 1 x 28 x 28 + 10 condition
        self.linear = nn.Sequential(
            nn.Linear(10, 1 * 28 * 28),
            nn.ReLU()
        )
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 6 * 6, 1),
            nn.Sigmoid()
        )

    def forward(self, x, c):
        c = self.linear(c)
        c = c.view(-1, 1, 28, 28)
        # 2 x 28 x 28
        x = torch.cat([x, c], dim=1)
        x = self.model(x)
        x = x.view(-1, 128 * 6 * 6)
        x = self.fc(x)
        return x

The discriminator also takes two inputs: an image (either a real one or a generated fake) and a label. The label is first mapped to a 1x28x28 tensor and concatenated with the image into a 2x28x28 tensor, which then passes through convolution, activation, and dropout layers, followed by a linear layer that outputs a single score (real or fake).
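
The same kind of hypothetical smoke test works for the discriminator (again assuming `import torch` and the class above):

disc = Discriminator()
imgs = torch.randn(4, 1, 28, 28)                     # stand-ins for normalized MNIST images
labels = torch.eye(10)[torch.randint(0, 10, (4,))]   # 4 one-hot labels
print(disc(imgs, labels).shape)                      # expected: torch.Size([4, 1]), scores in (0, 1)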

  3. Training
  • Training the discriminator

The discriminator should separate real images from fake ones as well as possible;
feed real images and their labels into the discriminator, and compute the loss between its output and 1;
generate fake images from noise, feed the fakes and the labels into the discriminator, and compute the loss between its output and 0;
backpropagate and step the optimizer.

  • Training the generator

The generator should make its generated images as close to real ones as possible, so that the discriminator cannot tell whether an image is real or generated;
feed the generated fake images into the discriminator, and compute the loss between its output and 1;
backpropagate and step the optimizer.

4. Full Code

import torch
from torch import nn, cuda
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def one_hot(x, class_count=10):
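    # Index into an identity matrix: each label selects its one-hot row,
    # e.g. one_hot(3) -> tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])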
    return torch.eye(class_count)[x, :]


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # two 128 x 7 x 7 feature maps; concatenated they give 256 x 7 x 7
        self.linear1 = nn.Sequential(
            nn.Linear(100, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.linear2 = nn.Sequential(
            nn.Linear(10, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )

        self.model = nn.Sequential(
            # 128 x 7 x 7
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # 1 x 28 x 28
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, c):
        x = self.linear1(x)
        x = x.view(-1, 128, 7, 7)
        c = self.linear2(c)
        c = c.view(-1, 128, 7, 7)
        # concatenate along the channel dimension: 256 x 7 x 7
        x = torch.cat([x, c], dim=1)
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # input: 1 x 28 x 28 + 10 condition
        self.linear = nn.Sequential(
            nn.Linear(10, 1 * 28 * 28),
            nn.ReLU()
        )
        self.model = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 6 * 6, 1),
            nn.Sigmoid()
        )

    def forward(self, x, c):
        c = self.linear(c)
        c = c.view(-1, 1, 28, 28)
        # 2 x 28 x 28
        x = torch.cat([x, c], dim=1)
        x = self.model(x)
        x = x.view(-1, 128 * 6 * 6)
        x = self.fc(x)
        return x


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))  # map pixels to [-1, 1] to match the Tanh output
])

dataset = torchvision.datasets.MNIST("./data", train=True,
                                     transform=transform,
                                     download=True,
                                     target_transform=one_hot)

dataloader = data.DataLoader(dataset, batch_size=512, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)
disc = Discriminator().to(device)
loss_fn = torch.nn.BCELoss()
opt_g = torch.optim.RMSprop(gen.parameters(), lr=0.0001)
opt_d = torch.optim.Adam(disc.parameters(), lr=0.0001)
num_epochs = 201

writer_g = SummaryWriter("/root/tf-logs/g")
writer_d = SummaryWriter("/root/tf-logs/d")

noise_seed = torch.randn(16, 100, device=device)
# 16 random integer labels in [0, 10), fixed so progress can be compared across epochs
label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
print(f"label seed: {label_seed}")

for epoch in range(num_epochs):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataset)
    loop = tqdm(dataloader, leave=True, desc=f"Epoch: {epoch}/{num_epochs}")
    for step, (img, label) in enumerate(loop):
        img = img.to(device)
        label = label.to(device)
        size = img.shape[0]
        random_seed = torch.randn(size, 100, device=device)

        # Train the discriminator
        opt_d.zero_grad()
        # Feed real images into the discriminator
        real_output = disc(img, label)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))
        # Generate fake images and feed them into the discriminator
        gen_img = gen(random_seed, label)
        fake_output = disc(gen_img.detach(), label)
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        opt_d.step()

        # Train the generator
        opt_g.zero_grad()
        fake_output = disc(gen_img, label)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))
        g_loss.backward()
        opt_g.step()

        with torch.no_grad():
            D_epoch_loss += d_loss.item()
            G_epoch_loss += g_loss.item()
            loop.set_postfix(G_loss=f"{np.round(G_epoch_loss, 2)}", D_loss=f"{np.round(D_epoch_loss, 2)}")

    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        writer_g.add_scalar("loss", G_epoch_loss, epoch)
        writer_d.add_scalar("loss", D_epoch_loss, epoch)

        if epoch % 20 == 0:
            # Rescale the Tanh output from [-1, 1] to [0, 1] so TensorBoard renders it correctly
            gen_img = (gen(noise_seed, label_seed_onehot) + 1) / 2
            writer_g.add_images("gen mnist", gen_img, epoch)

torch.save(gen.state_dict(), "./gen.pth")
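
With the weights saved, the generator can later be reloaded to draw samples of any chosen class. A minimal sketch (assumes the definitions above; the digit 7 is an arbitrary example):

gen = Generator().to(device)
gen.load_state_dict(torch.load("./gen.pth", map_location=device))
gen.eval()

digit = 7                                            # any class label 0-9
noise = torch.randn(16, 100, device=device)
labels = one_hot(torch.full((16,), digit)).to(device)
with torch.no_grad():
    imgs = gen(noise, labels)                        # 16 x 1 x 28 x 28, values in [-1, 1]
imgs = (imgs + 1) / 2                                # map Tanh output back to [0, 1] for display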

5. Training Results

(Figures of generated samples omitted.)

6. References

  1. https://arxiv.org/abs/1411.1784
  2. https://blog.csdn.net/qq_41647438/article/details/103007057
  3. https://blog.csdn.net/***_xujiping/article/details/102719363
  4. https://zhuanlan.zhihu.com/p/510346635
  5. https://www.jianshu.com/p/39c57e9a6630

Author: 机智的程序DOG

Post URL: https://www.cnblogs.com/Elijah-Z/p/16769594.html

License: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivatives 2.5 China Mainland License.
