Libtorch c++ 搭建全连接网络识别MINST手写数字
这是个完整的例子,用全连接网络方法识别手写数字,分为三部分,(1)搭建网络,(2)读取MNIST数据,(3)优化器设置,(4)训练网络。
1、网络搭建部分
用struct的方式建立自定义网络Net,它继承自torch::nn::Module,实现了forward函数,
该网络中注册的内置网络模块是三个线性网络,fc1,fc2,fc3
神经元的个数,fc1为(784,64),fc2(64,32),fc3(32,10),fc1的输入层个数为784,是因为MNIST图像的像素为28*28=784,网络的最后一层是log_softmax,输入是fc3的输出,fc3的输出神经元个数为10。它的输出神经元个数与输入个数一致,还是10,log_softmax()层的作用是接受一个实数向量计算概率分布然后取对数。整体的网络结构如下:
MNIST picture->tensor->fc1(784, 64)->relu()->dropout()->fc2(64,32)->relu()->fc3(32, 10)->log_softmax()->prediction
可以知道输入是长度为784的向量,输出是长度为10的概率分布的对数。
2、数据读取部分
与pytorch类似,libtorch也要求必须用make_data_loader这种多线形成数据加载的方式加载数据。这个没有新建数据集,而是采用了libtorch自带的MNIST数据集(libtorch通过torch::data::datasets封装了一些常见的数据集,方便使用),直接加载即可。然后通过map映射的方式把数据形成批量的,数据转化方式是栈(stack)的方式
3、优化器设置
优化器设置比较简单,选择的是随机提取下降,固定学习率。
4、训练部分
通过for训练的方式进行迭代训练,每个循环内,重置梯度为0,然后计算损失函数loss,通过loss的反向计算backward计算梯度,通过优化器optimizer.step()调正网络的权重,然后执行下次训练,每个一定次数保存当前的网络。
其中损失函数选择的是nll_loss,它输入要求是一个对数概率向量和一个目标标签. 它不会为我们计算对数概率. 适合网络的最后一层是log_softmax.
保存当前网络的方法是,torch::save(net, "net.pt")
#include <iostream>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>
// Define a new Module.
struct Net : torch::nn::Module {
Net() {
// Construct and register two Linear submodules.
fc1 = register_module("fc1", torch::nn::Linear(784, 64));
fc2 = register_module("fc2", torch::nn::Linear(64, 32));
fc3 = register_module("fc3", torch::nn::Linear(32, 10));
}
// Implement the Net's algorithm.
torch::Tensor forward(torch::Tensor x) {
// Use one of many tensor manipulation functions.
x = torch::relu(fc1->forward(x.reshape({ x.size(0), 784 })));
x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
x = torch::relu(fc2->forward(x));
x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
return x;
}
// Use one of many "standard library" modules.
torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};
int main()
{
std::cout << "Hello World!\n";
//torch::Tensor tensor = torch::eye(3);
torch::Tensor tensor = torch::rand({ 2,3 });
std::cout << tensor << std::endl;
torch::save(tensor, "tensor.pt");
const static int WIDTH = 512, HEIGHT = 512;
// Create a new Net.
auto net = std::make_shared<Net>();
// Create a multi-threaded data loader for the MNIST dataset.
auto data_loader = torch::data::make_data_loader(
torch::data::datasets::MNIST("../../pytorchCpp/data/mnist/MNIST/raw").map(
torch::data::transforms::Stack<>()),
/*batch_size=*/64);
// Instantiate an SGD optimization algorithm to update our Net's parameters.
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
std::vector<double> lossVec;
for (size_t epoch = 1; epoch <= 10; ++epoch) {
size_t batch_index = 0;
// Iterate the data loader to yield batches from the dataset.
for (auto& batch : *data_loader) {
// Reset gradients.
optimizer.zero_grad();
// Execute the model on the input data.
torch::Tensor prediction = net->forward(batch.data);
// Compute a loss value to judge the prediction of our model.
torch::Tensor loss = torch::nll_loss(prediction, batch.target);
// Compute gradients of the loss w.r.t. the parameters of our model.
loss.backward();
// Update the parameters based on the calculated gradients.
optimizer.step();
// Output the loss and checkpoint every 100 batches.
if (++batch_index % 100 == 0) {
std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
<< " | Loss: " << loss.item<float>() << std::endl;
lossVec.push_back(loss.item<double>());
// Serialize your model periodically as a checkpoint.
torch::save(net, "net.pt");
}
}
}
}