用Mnist数据集训练一个手写数字识别网络

Mnist数据集我找了半天才在哔哩哔哩找到一个下载链接,现在的网络下载文件太麻烦了。数据集中的文件格式参考如下链接:

我学习了两种方法。第一种是传统的BP神经网络模式;第二种是LeNet。这些代码已放在gitee上开源。

一、传统方法

开源地址是:Mnist: 基于libtorch1.13.1实现的Mnist数据集手写数字识别程序 (gitee.com),内附了Mnist数据集。代码中使用的是交叉熵损失函数,该函数的介绍可到官网上查找。链接在libtorch的头文件的注释里有。需要注意这个损失函数不需要输入归一化网络输出数据,即在你自定义的网络最后不需要加Softmax层使每个分类的数值之和为1。如下图:

本程序经测试在测试集10000张图片中平均识别率在98%左右。见下图:

二、LeNet网络

这个开源链接是:LeNet: 用libtorch1.13.1实现的LeNet (gitee.com)。网络结构参考自:Pytorch实现卷积神经网络(一) - 知乎 (zhihu.com)。这个网络使用的是Adam优化器,此优化器学习率需要设置的比SGD优化器小一些,否则可能会难以收敛。经对比卷积神经网络的速度明显比普通网络慢一些。训练效果如下图,比前面一种好一点。这个正确率在99%左右:

 

posted @ 2024-03-19 08:41  兜尼完  阅读(23)  评论(0编辑  收藏  举报