利用Libtorch c++创建并训练DCGAN网络生成手写数字MNIST

目录

什么是对抗生成网络GAN

生成网络模块

鉴别网络模块

数据集定义

数据加载

数据检查的输出结果

定义优化器

网络训练

模型的定期保存

全部源代码


我们的目的是从MNIST 数据集生成图片,将使用对抗生成网络(GAN)完成这个任务。具体说,将采用DCGAN架构,它是最早最简单的对抗生成网络,但足以完成这项任务。

什么是对抗生成网络GAN

GAN由两个不同的神经网络模型组成:一个生成器和一个鉴别器。生成器接收来自噪声分布的样本,其目的是将每个噪声样本转换为类似于目标分布的图像(在我们的例子中是MNIST数据集)。鉴别器依次从MNIST数据集接收真实图像,或从生成器接收图像。它被要求发出一个概率来判断一个特定图像是真实的(接近“1”)还是假的(接近“0”)。从鉴别器对生成器产生的图像的真实性反馈用于进一步训练生成器。关于真假图片的反馈用于优化鉴别器。理论上,生成器和鉴别器之间的微妙平衡使它们协同改进,导致生成器生成的图像与目标分布不可区分,从而欺骗鉴别器(那时)的优秀眼睛,使真实图像和伪图像的概率都达到“0.5”。对我们来说,最终的结果是一台机器,它接收噪声作为输入,并生成数字的真实图像作为输出。

生成网络模块

生成模块包含一系列的二维转置卷积、批量正态转化、ReLU激活单元。在forward方法把多个模块之间传递输入和输出。

生成网络的作用是接受一个随机数组成的数据序列,生成一个灰度图片。输入层的通道为kNoiseSize=100,尺寸为1*1,输出的通道为256,尺寸为4*4,然后依次得到尺寸为7*7, 14*14, 28*28的图像,MNIST数据集的图像尺寸就是28*28。采用二维卷积转置的尺寸计算公式,

H_{out}=(H_{in}-1)*stride[0]-2*padding[0]+kernel_size[0]+output_padding[0]

尺寸变化过程如下:

自定义的网络继承nn::Module模块,这里采用了初始化列表的方式,一个好处是不需要再定义复杂的构造函数,第二个好处是使用初始化列表少了一次调用默认构造函数的过程,这对于数据密集型的类来说,是非常高效的。。另外,由于c++语言自身没有反射功能,要求每个网络层都需要通过register_module()函数进行手动注册。另外,定义网络之后,也需要通过宏的方式注册自定义的网络模块,TORCH_MODULE(DCGANGenerator);

struct DCGANGeneratorImpl : nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
        : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
            .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
            .stride(2)
            .padding(1)
            .bias(false))
    {
        // register_module() is needed if we want to use the parameters() method later on
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }

    nn::ConvTranspose2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

鉴别网络模块

鉴别器类网络采用类似于卷积、批量规范化和激活的序列。然而,卷积现在是正常卷积而不是转置卷积,我们使用alpha值为0.2的leaky ReLU而不是vanilla ReLU。鉴别网络接收真实图像或生成器生成的假图像,通过一些列卷积和正态化操作,最终激活变成一个S型函数Sigmoid,它将值压缩到0到1之间的范围内。然后我们可以将这些压缩值解释为鉴别器赋予图像真实性的概率。

鉴别网络的通过多次的卷积计算,尺寸从28*28变为1,卷积计算尺寸的计算公式是:

$$L_{out}=floor((L_{in}+2padding-dilation(kernerl_size-1)-1)/stride+1)$$

具体变化过程是:

为了构建鉴别器,我们将尝试一些不同的东西:顺序模块Sequential Module。与Python一样,LibTorch在这里提供了两个用于模型定义的api:一个是通过连续函数传递输入的函数api(例如生成器模块示例),另一个是更面向对象的api,在这里我们构建了一个包含整个模型作为子模块的顺序模块。Sequential模块简单地执行函数的组合,第一个子模块的输出变成第二个子模块的输入,第三个模块的输出是第四个模块的输入,不需要再自定义forwrd函数。nn::Sequential类型的构造函数的参数是nn网络的列表。 

nn::Sequential discriminator(
    // Layer 1
    nn::Conv2d(
        nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 2
    nn::Conv2d(
        nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 3
    nn::Conv2d(
        nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 4
    nn::Conv2d(
        nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid());

数据集定义

定义生成模块和判别模块之后,需要加载可以用来训练的数据。面向c++的接口,与python类似,提供了强悍的并行数据加载器。该数据加载器可以从数据集中批量地加载数据, 并且提供了很多配置选项。

尽管python的数据加载器也是多进程的,c++的数据加载器是真正的多线程,不开启任意的新进程。

数据加载器是面向C++接口的一部分,包含在torch::data:: 名称空间下,这些接口包括不同的组成:

  • 数据加载器类
  • 定义数据集的接口
  • 定义数据转换的接口,它可以作用域数据集
  • 定义数据采样器,它可以生成数据集的索引
  • 已有的(内置的)数据集、转化和采样器

在这个例子中,我们可以使用面向c++接口提供的MNIST内置数据集,torch::data::datasets::MNIST,并执行两次变换。首先,把图片做正太化变化,这样可以把数据转化为-1到+1之间(原始数据是0到1之间),然后,应用入栈列队(Stack collation),它提取一组矩阵,并沿着第一维度把单个的矩阵入栈。

dataset的size()函数可以获取数据集的个数,除以批量加载时数据集kBatchSize的多少,可以计算需要多少批加载。

 //load the dataset
    auto dataset = torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw")
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());
    const int64_t batches_per_epoch = std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

数据加载

下一步,是创建数据加载器,并把数据集传递给它。通过torch::data::make_data_loader创建新的数据加载器,它返回数据集正确类型的地址std::unique_ptr(它依赖数据集的类型,采样器的类型和其他实现的细节)。

数据加载器有很多选项。您可以[在此处]检查全部选项。例如,为了加快数据加载速度,我们可以增加线程的数量。默认值为零,这意味着将使用主线程。如果将“workers”设置为“2”,将生成两个同时加载数据的线程。我们还应该将批大小从默认的'1'增加到更合理的值,比如'64'(kBatchSize的值)。因此,让我们创建一个“DataLoaderOptions”对象并设置适当的属性:

//define the data_loader
    auto data_loader = torch::data::make_data_loader(
        std::move(dataset),
        torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

数据检查的输出结果

数据加载器返回的数据类型是torch::data::Example,这个数据类型是简单的结构,拥有data字段存储数据,和target字段存储标签。因为前面使用了入栈操作,这里数据加载器仅返回单个样本,如果不进行入栈操作,则数据加载器返回的是列表形式。std::vectortorch::data::Example<>,每个元素是一批样本。

 //print to check the data
    for (torch::data::Example<>& batch : *data_loader) {
        std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
        for (int64_t i = 0; i < batch.data.size(0); ++i) {
            std::cout << batch.target[i].item<int64_t>() << " ";
        }
        std::cout << std::endl;
    }

这是kBatchSize=64时的数据检查结果

Batch size: 64 | Labels: 8 7 5 9 1 0 5 9 5 1 7 9 5 7 1 0 6 7 5 2 8 2 2 7 0 2 4 1 8 7 8 7 5 0 2 0 2 7 7 6 5 8 5 8 5 1 6 1 1 0 9 8 7 4 0 5 4 9 8 9 0 3 9 2
Batch size: 64 | Labels: 5 2 6 7 5 3 7 4 2 1 5 3 2 6 2 1 7 6 4 4 7 4 9 7 6 5 4 7 9 2 2 1 7 4 0 8 6 0 5 1 2 8 9 5 9 9 6 9 7 8 1 2 1 0 3 2 3 9 2 5 5 2 8 7
Batch size: 64 | Labels: 9 3 9 5 6 7 6 6 8 2 3 9 8 8 0 1 9 2 2 8 4 0 1 2 7 9 6 8 9 5 6 6 9 4 3 7 8 5 2 9 3 0 7 5 2 8 2 9 7 5 4 3 2 1 9 8 7 2 7 2 0 8 3 3
Batch size: 64 | Labels: 4 7 0 1 3 6 4 8 0 3 2 2 4 8 8 4 8 8 6 5 6 5 7 8 1 9 2 5 3 2 8 5 8 0 6 9 5 7 9 8 5 2 4 4 6 8 2 0 5 0 0 4 3 5 0 9 0 3 2 8 8 1 1 6
Batch size: 64 | Labels: 4 1 0 3 9 6 2 1 1 5 3 4 3 7 7 7 4 4 7 4 5 3 4 1 0 7 8 1 6 0 6 8 8 4 1 8 4 0 3 3 1 9 7 5 6 2 4 1 3 8 9 4 7 1 0 8 6 8 9 8 5 7 2 5
Batch size: 64 | Labels: 0 6 3 9 9 6 0 9 9 3 0 0 0 5 9 0 9 6 9 8 1 8 7 5 1 0 1 1 6 8 4 7 6 2 8 1 8 6 7 8 5 8 9 6 1 2 9 3 8 2 0 8 4 7 6 9 6 1 1 1 4 2 8 8
Batch size: 64 | Labels: 3 0 8 9 3 5 4 9 6 3 2 3 3 9 7 9 6 0 7 2 7 8 2 4 8 7 9 3 4 7 9 0 5 6 3 8 1 1 3 9 9 1 6 3 7 3 1 7 0 1 5 6 2 1 2 1 7 8 7 9 6 2 7 7
Batch size: 64 | Labels: 7 8 7 3 1 7 7 4 9 1 4 6 7 6 4 2 0 8 1 0 5 5 8 4 1 1 8 9 5 3 1 7 4 1 2 8 1 7 8 5 7 4 0 3 8 3 8 3 6 3 7 0 4 2 1 1 8 2 8 5 7 6 5 0
Batch size: 64 | Labels: 2 9 8 6 9 1 4 5 8 9 0 2 5 7 2 9 3 9 4 1 3 5 0 1 1 4 0 4 6 9 0 1 9 6 9 5 4 9 7 4 0 6 2 0 7 6 6 8 6 0 9 9 6 2 9 8 5 2 2 3 4 8 7 7
Batch size: 64 | Labels: 8 4 9 8 5 8 4 2 9 8 0 0 1 9 1 8 6 3 2 3 4 0 2 2 5 6 6 0 7 1 9 9 1 1 8 7 9 3 1 8 2 1 0 9 1 7 2 3 1 3 8 2 8 2 9 6 5 0 1 2 1 6 8 6
.....

定义优化器

下面要这个例子的算法部分,并实现生成器和判别器的微妙双人舞。首先创建两个优化器,一个用来优化生成器,一个用来优化判别器。用到的优化器是Adam算法。

就像这个例子用到的,面向c++的接口提供了Adagrad, Adam, LBFGS, RMSprop and SGD等优化算法的实现,具体的优化算法列表可以看这个文档

 //define the optimizer for these two net
    torch::optim::Adam generator_optimizer(
        generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
    torch::optim::Adam discriminator_optimizer(
        discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));

网络训练

下面就是更新训练循环模块,需要增加两个训练,在数据加载器获取每组数据,然后在每组中训练对抗生成网络模型。

在训练中,首先在真实图片上训练鉴别器,把真是图片赋予很高的概率,这里通过torch::empty(batch.data.size(0)).uniform_(0.8, 1.0)作为标签的概率。 选择0.8到1.0的均匀分布作为目标概率是为了鉴别器训练更稳健。这种技巧称为标签光滑。

在评价判别器之前,我么需要把鉴别器的梯度参数归零,计算损失之后,通过调用d_loss.backward()可是执行神经网络的反向传播算法,计算新的梯度。不仅在真实的数据集上,在合成的图片上也执行这个过程。合成的图片是通过生成网络计算得到,生成网络的输入数据是个随机噪声序列。把生成的图片传递给鉴别器,想让它给出个很低的真实度判别结果,理想结果是0。计算判别器在真实样本和合成图片上的损失之后,可以通过优化器更新它的参数。

为了训练生成器,也需要把生成器的梯度归零,然后重新判别在合成数据上的表现。但是此时,需要把合成数据的标签赋值为概率1,意味着生成器可以生成让判别认为是真的结果。为此,需要给表爱你fake_labels赋值为1,最后是用鉴别器的优化算法更新参数。

for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;
        for (torch::data::Example<>& batch : *data_loader) {
            // Train discriminator with real images.
            discriminator->zero_grad();
            torch::Tensor real_images = batch.data;
            torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
            torch::Tensor real_output = discriminator->forward(real_images);
            torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
            d_loss_real.backward();

            // Train discriminator with fake images.
            torch::Tensor noise = torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 });
            torch::Tensor fake_images = generator->forward(noise);
            torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
            torch::Tensor fake_output = discriminator->forward(fake_images.detach());
            torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
            d_loss_fake.backward();

            torch::Tensor d_loss = d_loss_real + d_loss_fake;
            discriminator_optimizer.step();

            // Train generator.
            generator->zero_grad();
            fake_labels.fill_(1);
            fake_output = discriminator->forward(fake_images);
            torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
            g_loss.backward();
            generator_optimizer.step();

            std::printf(
                "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
                epoch,
                kNumberOfEpochs,
                ++batch_index,
                batches_per_epoch,
                d_loss.item<float>(),
                g_loss.item<float>());
        }
    }

到此,我们基本上可以在cpu上训练模型。但是,目前还没有涉及捕捉状态或样本的输出,后面会提到。现在,模型可以做的事,主要依赖于生成的图片是否看起来有意义(像真实的一样)。代码执行的运行结果如下:

[ 1/30][200/938] D_loss: 0.3507 | G_loss: 7.64503
-> checkpoint 2
[ 1/30][400/938] D_loss: 2.7487 | G_loss: 3.29385
-> checkpoint 3
[ 1/30][600/938] D_loss: 0.9987 | G_loss: 1.8063
-> checkpoint 4
[ 1/30][800/938] D_loss: 0.7328 | G_loss: 1.8110
-> checkpoint 5
[ 2/30][200/938] D_loss: 0.9540 | G_loss: 0.9474
-> checkpoint 6
[ 2/30][400/938] D_loss: 0.7088 | G_loss: 2.2973
-> checkpoint 7
[ 2/30][600/938] D_loss: 0.4907 | G_loss: 2.4834
-> checkpoint 8
[ 2/30][800/938] D_loss: 0.5548 | G_loss: 2.5090
-> checkpoint 9
[ 3/30][200/938] D_loss: 0.6886 | G_loss: 3.2052
-> checkpoint 10
[ 3/30][400/938] D_loss: 0.5958 | G_loss: 2.8089
-> checkpoint 11
[ 3/30][522/938] D_loss: 0.6508 | G_loss: 2.5090

模型的定期保存

if (batch_index % kCheckpointEvery == 0) {
  // Checkpoint the model and optimizer state.
  torch::save(generator, "generator-checkpoint.pt");
  torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
  torch::save(discriminator, "discriminator-checkpoint.pt");
  torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
  // Sample the generator and save the images.
  torch::Tensor samples = generator->forward(torch::randn({8, kNoiseSize, 1, 1}));
  torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
  std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}

全部源代码

#include <iostream>
#include <tuple>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/torch.h>

using namespace torch;

struct DCGANGeneratorImpl : nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
        : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
            .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
            .stride(2)
            .padding(1)
            .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
            .stride(2)
            .padding(1)
            .bias(false))
    {
        // register_module() is needed if we want to use the parameters() method later on
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }

    nn::ConvTranspose2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

nn::Sequential discriminator(
    // Layer 1
    nn::Conv2d(
        nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 2
    nn::Conv2d(
        nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 3
    nn::Conv2d(
        nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    // Layer 4
    nn::Conv2d(
        nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid());


int main() {
    // The size of the noise vector fed to the generator.
    const int64_t kNoiseSize = 100;

    // The batch size for training.
    const int64_t kBatchSize = 64;

    // The number of epochs to train.
    const int64_t kNumberOfEpochs = 30;

    // Where to find the MNIST dataset.
    const char* kDataFolder = "./data";

    // After how many batches to create a new checkpoint periodically.
    const int64_t kCheckpointEvery = 200;

    // How many images to sample at every checkpoint.
    const int64_t kNumberOfSamplesPerCheckpoint = 10;

    // Set to `true` to restore models and optimizers from previously saved
    // checkpoints.
    const bool kRestoreFromCheckpoint = false;

    // After how many batches to log a new update with the loss value.
    const int64_t kLogInterval = 10;

    DCGANGenerator generator(kNoiseSize);

    //load the dataset
    auto dataset = torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw")
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());
    const int64_t batches_per_epoch = std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

    //define the data_loader
    auto data_loader = torch::data::make_data_loader(
        std::move(dataset),
        torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));


    //print to check the data
    for (torch::data::Example<>& batch : *data_loader) {
        std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
        for (int64_t i = 0; i < batch.data.size(0); ++i) {
            std::cout << batch.target[i].item<int64_t>() << " ";
        }
        std::cout << std::endl;
    }
    //define the optimizer for these two net
    torch::optim::Adam generator_optimizer(
        generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999)));
    torch::optim::Adam discriminator_optimizer(
        discriminator->parameters(), torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999)));

    int64_t checkpoint_counter = 0;
    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;
        for (torch::data::Example<>& batch : *data_loader) {
            // Train discriminator with real images.
            discriminator->zero_grad();
            torch::Tensor real_images = batch.data;
            torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
            torch::Tensor real_output = discriminator->forward(real_images);
            torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
            d_loss_real.backward();

            // Train discriminator with fake images.
            torch::Tensor noise = torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 });
            torch::Tensor fake_images = generator->forward(noise);
            torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
            torch::Tensor fake_output = discriminator->forward(fake_images.detach());
            torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
            d_loss_fake.backward();

            torch::Tensor d_loss = d_loss_real + d_loss_fake;
            discriminator_optimizer.step();

            // Train generator.
            generator->zero_grad();
            fake_labels.fill_(1);
            fake_output = discriminator->forward(fake_images);
            torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
            g_loss.backward();
            generator_optimizer.step();

            //print the status
            std::printf(
                "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
                epoch,
                kNumberOfEpochs,
                ++batch_index,
                batches_per_epoch,
                d_loss.item<float>(),
                g_loss.item<float>());

            //save current model
            if (batch_index % kCheckpointEvery == 0) {
                // Checkpoint the model and optimizer state.
                torch::save(generator, "generator-checkpoint.pt");
                torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
                torch::save(discriminator, "discriminator-checkpoint.pt");
                torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
                // Sample the generator and save the images.
                torch::Tensor samples = generator->forward(torch::randn({ 8, kNoiseSize, 1, 1 }));
                torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
                std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
            }
        }
    }
}

 

posted @ 2022-08-21 10:13  Oliver2022  阅读(87)  评论(0编辑  收藏  举报