如何使用 libtorch 实现 LeNet 网络?

如何使用 libtorch 实现 LeNet 网络?

LeNet 网络论文地址:
http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

LeNet

C1 卷积层

{1,1,28,28} 是什么?

1 输入的批次
1 图像的通道大小
28 图像的高
28 图像的宽

输入:{1,1,28,28}

通过填充一个边界 2 ,使得输入变成 {1,1,32,32}

滑动窗口大小:{5,5}

输出:{1,6,32,32}

S2 降采样

输入:{1,6,32,32}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,6,14,14}

C3 卷积层

输入:{1,16,14,14}

滑动窗口大小:{5,5}

输出:{1,16,10,10}

S4 降采样

输入:{1,16,10,10}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,16,5,5}

C5 卷积层

输入:{1,16,5,5}

滑动窗口大小:{5,5}

输出:{1,120,1,1}

F6 全连接层

这里要把网络形状从 {1,120,1,1} 改变改变成 {1,120}

第一个全连接
输入:{1,120}
输出:{1,84}

第二个全连接
输入:{1,84}
输出:{84,10}

0~9 总共是 10 个类别嘛,这里就输出 10个就行了。

全连接就是线性层,网络形状不一样不能全连接的,所以这里要把形状改变成一样的。
基本按照那图写一遍就明白了。

关于输入和输出的网络推断公式可以去参考 pytorch 里面的函数说明,上面都有写推断公式滴。

// Define a new Module.
struct Net : torch::nn::Module {
	Net() {
		conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, /*kernel_size*/{ 5,5 }).padding(/*28->32*/{2,2})));
		conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, /*kernel_size*/{5,5})));
		conv3 = register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 120, /*kernel_size*/{5,5})));
		fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(120, 84)));
		fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(84, 10)));
	}

	// Implement the Net's algorithm.
	torch::Tensor forward(torch::Tensor x) {
		x = conv1->forward(x);//6@28x28
		x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//6@14x14
		x = conv2->forward(x);//16@10x10
		x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//16@10x10
		
		x = conv3->forward(x);//120@1x1
		x = x.view({ x.size(0),-1 });//-1 表示自动推理计算出该值
		x = fc1->forward(x);//120->84
		x = fc2->forward(x);//84->10
		x = torch::log_softmax(x,/*dim=*/1);
		return x;
	}

	// Use one of many "standard library" modules.
	torch::nn::Conv2d conv1 { nullptr };
	torch::nn::Conv2d conv2 { nullptr };
	torch::nn::Conv2d conv3 { nullptr };
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{ nullptr };
};
posted @ 2019-04-15 15:34  學海無涯  阅读(731)  评论(0编辑  收藏  举报