Getting Started with Machine Learning in TensorFlow.NET [5]: Handwritten Digit Recognition (MNIST) with a Neural Network
Starting with this article we finally get down to some real work; everything before this was preparation. This time we tackle a classic machine learning problem: MNIST handwritten digit recognition.
First, a brief introduction to the dataset. Please start by extracting the file TF_Net\Asset\mnist_png.tar.gz.
The archive contains two folders, training and validation. The training folder holds 60,000 training images and the validation folder holds 10,000 evaluation images. Each image is 28*28 pixels, and the images are grouped into ten subfolders named 0 through 9.
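If you want to sanity-check the extracted archive, a quick directory scan like the one sketched below reports the image count under each digit folder. The root path here is an assumption; point it at wherever you extracted mnist_png.

using System;
using System.IO;

// Count the PNG files under each digit folder of the training set.
// The path is an assumption based on the archive layout described above.
var root = new DirectoryInfo(@"TF_Net\Asset\mnist_png\training");
int total = 0;
foreach (var digitDir in root.GetDirectories())
{
    int count = digitDir.GetFiles("*.png").Length;
    total += count;
    Console.WriteLine($"{digitDir.Name}: {count} images");
}
Console.WriteLine($"total: {total} images");   // expect 60000 for the training folder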
The overall program flow is basically the same as the BMI analysis program from the previous article; after all, both are multi-class classification. There are a few differences:
1. The feature data (input) of the BMI program is a one-dimensional array containing two numbers, while the MNIST feature data is a 28*28 two-dimensional array;
2. The BMI program has 3 outputs, while MNIST has 10.
The network model is built as follows:
private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes

/// <summary>
/// Build the network model
/// </summary>
private Model BuildModel()
{
    // Network parameters
    int n_hidden_1 = 128; // 1st layer number of neurons.
    int n_hidden_2 = 128; // 2nd layer number of neurons.
    float scale = 1.0f / 255;

    var model = keras.Sequential(new List<ILayer>
    {
        keras.layers.InputLayer((img_rows, img_cols)),
        keras.layers.Flatten(),
        keras.layers.Rescaling(scale),
        keras.layers.Dense(n_hidden_1, activation: keras.activations.Relu),
        keras.layers.Dense(n_hidden_2, activation: keras.activations.Relu),
        keras.layers.Dense(num_classes, activation: keras.activations.Softmax)
    });

    return model;
}
This network uses two new methods that deserve an explanation:
1. Flatten: as the name suggests, this flattens the 28*28 two-dimensional array into a one-dimensional array of 784 values, because the subsequent Dense layers cannot operate on a two-dimensional array;
2. Rescaling: this multiplies every value by a coefficient. The data we read from each image is the grayscale value of each pixel, which ranges from 0 to 255, so we scale the values down to within 1 to keep the later computations from overflowing. The sketch after this list shows by hand what these two layers do.
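To make the effect of these two layers concrete, here is a small standalone sketch (plain C#, no TensorFlow.NET involved) that does by hand what Flatten and Rescaling do to a single 28*28 grayscale image; the 'image' array stands in for the pixel values read from a PNG:

// Hand-rolled equivalent of Flatten + Rescaling for one image.
float[,] image = new float[28, 28];          // 2-D input with grayscale values 0..255
float scale = 1.0f / 255;

float[] flattened = new float[28 * 28];      // Flatten: 28*28 -> 784
for (int row = 0; row < 28; row++)
    for (int col = 0; col < 28; col++)
        flattened[row * 28 + col] = image[row, col] * scale;   // Rescaling: 0..255 -> 0..1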
Everything else is basically the same as described in the previous article. The full code is listed below:
// using directives (assumed for a recent TensorFlow.NET release; adjust to your version):
using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

/// <summary>
/// Handwritten digit recognition with a neural network
/// </summary>
public class NN_MultipleClassification_MNIST
{
    private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_data.bin";
    private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_label.bin";

    private readonly int img_rows = 28;
    private readonly int img_cols = 28;
    private readonly int num_classes = 10; // total classes

    public void Run()
    {
        var model = BuildModel();
        model.summary();

        model.compile(optimizer: keras.optimizers.Adam(0.001f),
            loss: keras.losses.SparseCategoricalCrossentropy(),
            metrics: new[] { "accuracy" });

        (NDArray train_x, NDArray train_y) = LoadTrainingData();
        model.fit(train_x, train_y, batch_size: 1024, epochs: 10);

        test(model);
    }

    /// <summary>
    /// Build the network model
    /// </summary>
    private Model BuildModel()
    {
        // Network parameters
        int n_hidden_1 = 128; // 1st layer number of neurons.
        int n_hidden_2 = 128; // 2nd layer number of neurons.
        float scale = 1.0f / 255;

        var model = keras.Sequential(new List<ILayer>
        {
            keras.layers.InputLayer((img_rows, img_cols)),
            keras.layers.Flatten(),
            keras.layers.Rescaling(scale),
            keras.layers.Dense(n_hidden_1, activation: keras.activations.Relu),
            keras.layers.Dense(n_hidden_2, activation: keras.activations.Relu),
            keras.layers.Dense(num_classes, activation: keras.activations.Softmax)
        });

        return model;
    }

    /// <summary>
    /// Load the training data, preferring the serialized binary cache
    /// </summary>
    private (NDArray, NDArray) LoadTrainingData()
    {
        try
        {
            Console.WriteLine("Load data");
            IFormatter serializer = new BinaryFormatter();

            FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
            float[,,] arrx = serializer.Deserialize(loadFile) as float[,,];

            loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
            int[] arry = serializer.Deserialize(loadFile) as int[];
            Console.WriteLine("Load data success");

            return (np.array(arrx), np.array(arry));
        }
        catch (Exception ex)
        {
            // Cache not found or unreadable: fall back to reading the PNG files
            Console.WriteLine($"Load data Exception:{ex.Message}");
            return LoadRawData();
        }
    }

    /// <summary>
    /// Read the raw PNG images and cache the arrays to binary files
    /// </summary>
    private (NDArray, NDArray) LoadRawData()
    {
        Console.WriteLine("LoadRawData");

        int total_size = 60000;
        float[,,] arrx = new float[total_size, img_rows, img_cols];
        int[] arry = new int[total_size];

        int count = 0;

        var TrainingImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\training";
        DirectoryInfo RootDir = new DirectoryInfo(TrainingImagePath);
        foreach (var Dir in RootDir.GetDirectories())
        {
            foreach (var file in Dir.GetFiles("*.png"))
            {
                Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
                if (bmp.Width != img_cols || bmp.Height != img_rows)
                {
                    continue;
                }

                for (int row = 0; row < img_rows; row++)
                    for (int col = 0; col < img_cols; col++)
                    {
                        // Grayscale value of each pixel; the folder name is the label
                        var pixel = bmp.GetPixel(col, row);
                        int val = (pixel.R + pixel.G + pixel.B) / 3;

                        arrx[count, row, col] = val;
                        arry[count] = int.Parse(Dir.Name);
                    }

                count++;
            }
            Console.WriteLine($"Load image data count={count}");
        }
        Console.WriteLine("LoadRawData finished");

        // Save data
        Console.WriteLine("Save data");
        IFormatter serializer = new BinaryFormatter();

        // Serialize the feature and label arrays for the next run
        FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
        serializer.Serialize(saveFile, arrx);
        saveFile.Close();

        saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
        serializer.Serialize(saveFile, arry);
        saveFile.Close();
        Console.WriteLine("Save data finished");

        return (np.array(arrx), np.array(arry));
    }

    /// <summary>
    /// Use the trained model on a handful of test images
    /// </summary>
    private void test(Model model)
    {
        var TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\test";
        DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
        foreach (var image in TestDir.GetFiles("*.png"))
        {
            var x = LoadImage(image.FullName);
            var pred_y = model.Apply(x);
            var result = argmax(pred_y[0].numpy());
            Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
        }
    }

    private NDArray LoadImage(string filename)
    {
        float[,,] arrx = new float[1, img_rows, img_cols];
        Bitmap bmp = (Bitmap)Image.FromFile(filename);

        for (int row = 0; row < img_rows; row++)
            for (int col = 0; col < img_cols; col++)
            {
                var pixel = bmp.GetPixel(col, row);
                int val = (pixel.R + pixel.G + pixel.B) / 3;
                arrx[0, row, col] = val;
            }

        return np.array(arrx);
    }

    /// <summary>
    /// Index of the largest value among the 10 class probabilities
    /// </summary>
    private int argmax(NDArray array)
    {
        var arr = array.reshape(-1);

        float max = 0;
        for (int i = 0; i < 10; i++)
        {
            if (arr[i] > max)
            {
                max = arr[i];
            }
        }

        for (int i = 0; i < 10; i++)
        {
            if (arr[i] == max)
            {
                return i;
            }
        }

        return 0;
    }
}
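With the class above in place, the whole train-and-test flow is kicked off with a single call, for example from a console entry point (the surrounding Main scaffolding is assumed):

// Kick off training and the small test run described below.
new NN_MultipleClassification_MNIST().Run();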
Two more notes:
1. Reading all the images is fairly slow, so I used a shortcut: once the data has been read, it is serialized to a binary file, and subsequent runs simply deserialize that file, which speeds things up considerably. If the bin files cannot be found, the data is read from the images again. I did not commit the bin files to the git repository, so the first run after downloading the project takes a little while. (An alternative cache format that avoids BinaryFormatter is sketched after this list.)
2. I did not use the validation images for a proper evaluation; I simply picked 20 samples and spot-checked them. (A sketch of evaluating the full validation set also follows below.)
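On note 1: BinaryFormatter is disabled by default on newer .NET runtimes, so the serialization calls above may throw on some setups. A plain BinaryWriter/BinaryReader cache works just as well; the sketch below is a hypothetical alternative, and the method names and single-file layout are mine, not part of the project code.

// Hypothetical cache helpers: store both arrays in one file instead of two.
private static void SaveCache(string path, float[,,] x, int[] y)
{
    using var w = new BinaryWriter(File.Create(path));
    int n = x.GetLength(0), rows = x.GetLength(1), cols = x.GetLength(2);
    w.Write(n); w.Write(rows); w.Write(cols);
    foreach (float v in x) w.Write(v);       // multi-dimensional arrays enumerate in row-major order
    foreach (int label in y) w.Write(label);
}

private static (float[,,], int[]) LoadCache(string path)
{
    using var r = new BinaryReader(File.OpenRead(path));
    int n = r.ReadInt32(), rows = r.ReadInt32(), cols = r.ReadInt32();
    var x = new float[n, rows, cols];
    for (int i = 0; i < n; i++)
        for (int row = 0; row < rows; row++)
            for (int col = 0; col < cols; col++)
                x[i, row, col] = r.ReadSingle();
    var y = new int[n];
    for (int i = 0; i < n; i++) y[i] = r.ReadInt32();
    return (x, y);
}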
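On note 2: if you do want a number over the whole validation set, something along the lines of the sketch below should do. LoadValidationData is a hypothetical helper assumed to read mnist_png\validation the same way LoadRawData reads the training folder, and model.evaluate follows the Keras-style API exposed by TensorFlow.NET, whose exact return type can differ between versions.

// Hedged sketch: report loss and accuracy over the full validation set.
private void Evaluate(Model model)
{
    (NDArray val_x, NDArray val_y) = LoadValidationData();   // hypothetical helper, mirrors LoadRawData
    var results = model.evaluate(val_x, val_y, batch_size: 1024);

    // Assuming a dictionary-like result of metric name -> value.
    foreach (var item in results)
        Console.WriteLine($"{item.Key}: {item.Value}");
}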
[Related Resources]
Source code (Git): https://gitee.com/seabluescn/tf_not.git
Project name: NN_MultipleClassification_MNIST