使用C++解析MNIST数据库
遇到的几个大坑
1.官方主页给出的每个文件的字节数容易让人误解:training set images (9912422 bytes) 是解压前(.gz压缩包)的字节数,解压后的字节数应该为47,040,016,即4 + 4 + 4 + 4 + 60000 * 28 * 28(4个4字节的文件头字段,加上60000张28×28的图片,每个像素占1字节)。
2.Windows下以文本方式"r"打开文件时,fgetc读到字节0x1A(Ctrl-Z)会把它当作文件结束,从而提前返回EOF;改成"rb",以二进制字节流方式读取即可。
3.Visual Studio 2017默认的堆栈大小为1MB,若不修改,程序会出现栈溢出错误:属性->链接器->系统->将堆栈大小和堆栈保留大小修改为100000000。
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
// Number of training images/labels that main() asks the parsers to read.
// NOTE(review): the notes at the top of this file describe 60000 training
// images, so 6000 looks like a deliberate subset — confirm.
const int MnistTrainNumber = 6000;
// Number of test images/labels that main() asks the parsers to read.
// NOTE(review): presumably also a subset of the full test file — confirm.
const int MnistTestNumber = 1000;
//存储像素信息
struct Image
{
cv::Mat pixs;
Image()
{
pixs.create(Size(28, 28), CV_8U);
}
};
struct MnistImage
{
//检验值
int magicNumber;
//图片数量
int number;
//图片行数
int rows;
//图片列数
int cols;
//图片数组
vector<Image> images;
};
// Parsed contents of a label file: the two header integers in the order
// parseMnistLabel() reads them, followed by the decoded labels.
// Every integer defaults to 0 so that an instance whose file failed to open
// never exposes indeterminate values.
struct MnistLabel
{
    // Check value from the first header field (main() expects 2049).
    int magicNumber = 0;
    // Label count declared by the header.
    int number = 0;
    // Decoded labels, one int per byte read, in file order.
    std::vector<int> labels;
};
// Training set: the parsed image data paired with the parsed label data.
struct MnistTrainSet
{
// Image header fields and images filled in by parseMnistImage().
MnistImage trainImages;
// Label header fields and labels filled in by parseMnistLabel().
MnistLabel trainLabels;
};
// Test set: the parsed image data paired with the parsed label data.
// NOTE(review): the members are named train* although main() fills them from
// the t10k (test) files; the names look copied from MnistTrainSet. Renaming
// would require updating every user of this struct.
struct MnistTestSet
{
// Image header fields and images of the test set (despite the name).
MnistImage trainImages;
// Label header fields and labels of the test set (despite the name).
MnistLabel trainLabels;
};
// Reads `length` consecutive bytes starting at the current position of `file`
// and combines them, most significant byte first, into a single integer.
//
// Returns the decoded value, or -1 when `file` is null or the stream ends (or
// fails) before `length` bytes were read. A `length` of 0 or less reads
// nothing and returns 0.
// Note: four 0xFF bytes also decode to -1 on the usual 32-bit int; callers
// that must tell the two apart should check feof()/ferror() on the stream.
int readData(FILE *file, int length)
{
    if (file == nullptr) return -1;

    // Accumulate unsigned so that a leading byte >= 0x80 cannot trigger signed
    // overflow (undefined behaviour) when four bytes are combined.
    unsigned int value = 0;
    for (int i = 0; i < length; i++)
    {
        const int byte = fgetc(file);
        // fgetc() reports end-of-file/error as EOF; folding that -1 into the
        // result would silently corrupt it.
        if (byte == EOF) return -1;
        value = (value << 8) | static_cast<unsigned int>(byte);
    }
    return static_cast<int>(value);
}
//解析图片字节流文件
int parseMnistImage(const char *fileName, MnistImage &mnistImage, int imagesNumber)
{
FILE *out = fopen(fileName, "rb");
if (out == NULL) return -1;
mnistImage.magicNumber = readData(out, 4);
mnistImage.number = readData(out, 4);
mnistImage.rows = readData(out, 4);
mnistImage.cols = readData(out, 4);
for (int k = 0; k < imagesNumber; k++)
{
Image image;
for (int i = 0; i < 28; i++)
{
for (int j = 0; j < 28; j++)
{
int x = fgetc(out);
image.pixs.at<uchar>(i, j) = x;
}
}
mnistImage.images.push_back(image);
}
fclose(out);
return mnistImage.magicNumber;
}
// Parses a label byte-stream file.
//
// Reads two 4-byte header integers (magic number, label count) into
// `mnistLabel`, then appends up to `labelNumber` one-byte labels to
// `mnistLabel.labels`.
//
// Parameters:
//   fileName    - path of the file to open (opened in binary mode "rb").
//   mnistLabel  - receives the header fields; decoded labels are appended.
//   labelNumber - maximum number of labels to read; clamped to the count
//                 declared in the header so the read never runs past the data.
//
// Returns the magic number from the header on success, or -1 when the file
// cannot be opened, the header is incomplete, or the label data ends early.
// Labels decoded before a failure stay in `mnistLabel.labels`.
int parseMnistLabel(const char *fileName, MnistLabel &mnistLabel, int labelNumber)
{
    if (fileName == nullptr) return -1;
    FILE *in = fopen(fileName, "rb");
    if (in == nullptr) return -1;

    mnistLabel.magicNumber = readData(in, 4);
    mnistLabel.number = readData(in, 4);

    // Check the stream state directly: a short or unreadable header cannot be
    // trusted.
    if (feof(in) || ferror(in))
    {
        fclose(in);
        return -1;
    }

    const int count = std::min(labelNumber, mnistLabel.number);
    if (count > 0)
    {
        mnistLabel.labels.reserve(mnistLabel.labels.size() + static_cast<std::size_t>(count));
    }

    for (int i = 0; i < count; i++)
    {
        const int label = fgetc(in);
        if (label == EOF)
        {
            // Truncated file: do not store EOF (-1) as if it were a label.
            fclose(in);
            return -1;
        }
        mnistLabel.labels.push_back(label);
    }

    fclose(in);
    return mnistLabel.magicNumber;
}
// Shows `mat` in a window titled "virtualizeData", then calls waitKey() with
// its default delay (per the OpenCV docs this waits for a key press, which
// keeps the window on screen).
void virtualizeData(Mat &mat)
{
    const char *const windowTitle = "virtualizeData";
    cv::imshow(windowTitle, mat);
    cv::waitKey();
}
int main()
{
MnistTrainSet mnistTrainSet;
MnistTestSet mnistTestSet;
//magic number分别为2051,2049,2051,2049,与官方提供的检验值比对以确定解析程序是否有误
cout << parseMnistImage("train-images.idx3-ubyte", mnistTrainSet.trainImages, MnistTrainNumber) << endl;
cout << parseMnistLabel("train-labels.idx1-ubyte", mnistTrainSet.trainLabels, MnistTrainNumber) << endl;
cout << parseMnistImage("t10k-images.idx3-ubyte", mnistTestSet.trainImages, MnistTestNumber) << endl;
cout << parseMnistLabel("t10k-labels.idx1-ubyte", mnistTestSet.trainLabels, MnistTestNumber) << endl;
//可视化训练集中第k张图片
int k = 5;
virtualizeData(mnistTrainSet.trainImages.images[k].pixs);
return 0;
}