Libtorch c++ 搭建全连接网络识别MINST手写数字

这是个完整的例子,用全连接网络方法识别手写数字,分为三部分,(1)搭建网络,(2)读取MNIST数据,(3)优化器设置,(4)训练网络。

1、网络搭建部分

用struct的方式建立自定义网络Net,它继承自torch::nn::Module,实现了forward函数,

该网络中注册的内置网络模块是三个线性网络,fc1,fc2,fc3

神经元的个数,fc1为(784,64),fc2(64,32),fc3(32,10),fc1的输入层个数为784,是因为MNIST图像的像素为28*28=784,网络的最后一层是log_softmax,输入是fc3的输出,fc3的输出神经元个数为10。它的输出神经元个数与输入个数一致,还是10,log_softmax()层的作用是接受一个实数向量计算概率分布然后取对数。整体的网络结构如下:

MNIST picture->tensor->fc1(784, 64)->relu()->dropout()->fc2(64,32)->relu()->fc3(32, 10)->log_softmax()->prediction

可以知道输入是长度为784的向量,输出是长度为10的概率分布的对数。

2、数据读取部分

与pytorch类似,libtorch也要求必须用make_data_loader这种多线形成数据加载的方式加载数据。这个没有新建数据集,而是采用了libtorch自带的MNIST数据集(libtorch通过torch::data::datasets封装了一些常见的数据集,方便使用),直接加载即可。然后通过map映射的方式把数据形成批量的,数据转化方式是栈(stack)的方式

3、优化器设置

优化器设置比较简单,选择的是随机提取下降,固定学习率。

4、训练部分

通过for训练的方式进行迭代训练,每个循环内,重置梯度为0,然后计算损失函数loss,通过loss的反向计算backward计算梯度,通过优化器optimizer.step()调正网络的权重,然后执行下次训练,每个一定次数保存当前的网络。

其中损失函数选择的是nll_loss,它输入要求是一个对数概率向量和一个目标标签. 它不会为我们计算对数概率. 适合网络的最后一层是log_softmax.

保存当前网络的方法是,torch::save(net, "net.pt")

#include <iostream>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>

// Define a new Module.
struct Net : torch::nn::Module {
    Net() {
        // Construct and register two Linear submodules.
        fc1 = register_module("fc1", torch::nn::Linear(784, 64));
        fc2 = register_module("fc2", torch::nn::Linear(64, 32));
        fc3 = register_module("fc3", torch::nn::Linear(32, 10));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        // Use one of many tensor manipulation functions.
        x = torch::relu(fc1->forward(x.reshape({ x.size(0), 784 })));
        x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
        x = torch::relu(fc2->forward(x));
        x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};

int main()
{
    std::cout << "Hello World!\n";
    //torch::Tensor tensor = torch::eye(3);
    torch::Tensor tensor = torch::rand({ 2,3 });
    std::cout << tensor << std::endl;
    torch::save(tensor, "tensor.pt");

    const static int WIDTH = 512, HEIGHT = 512;

    // Create a new Net.
    auto net = std::make_shared<Net>();

    // Create a multi-threaded data loader for the MNIST dataset.
    auto data_loader = torch::data::make_data_loader(
        torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw").map(
            torch::data::transforms::Stack<>()),
        /*batch_size=*/64);

    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
    
    std::vector<double> lossVec;
    for (size_t epoch = 1; epoch <= 10; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
            // Reset gradients.
            optimizer.zero_grad();
            // Execute the model on the input data.
            torch::Tensor prediction = net->forward(batch.data);
            // Compute a loss value to judge the prediction of our model.
            torch::Tensor loss = torch::nll_loss(prediction, batch.target);
            // Compute gradients of the loss w.r.t. the parameters of our model.
            loss.backward();
            // Update the parameters based on the calculated gradients.
            optimizer.step();
            // Output the loss and checkpoint every 100 batches.
            if (++batch_index % 100 == 0) {
                std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                    << " | Loss: " << loss.item<float>() << std::endl;
                lossVec.push_back(loss.item<double>());
                // Serialize your model periodically as a checkpoint.
                torch::save(net, "net.pt");
            }
            
        }
    }
  
        
    
}

 

 损失函数

posted @ 2022-08-21 10:13  Oliver2022  阅读(221)  评论(0编辑  收藏  举报