使用C++解析MNIST数据库

遇到的几个大坑
1.官方主页给出的每个文件的字节数具有误导性:training set images (9912422 bytes) 是解压前(压缩包)的大小,解压后的字节数应该为47,040,016,这个数等于4 + 4 + 4 + 4 + 60000 * 28 * 28,即4个4字节的文件头字段加上60000张28*28的图片。
2.Windows下以文本方式"r"打开文件时,fgetc会把数据中的字节0x1A(Ctrl+Z)当作文件结束标志,从而提前返回EOF;改成"rb",以二进制字节流方式读取即可。
3.Visual Studio 2017默认的堆栈大小为1MB,若不修改,程序会出现栈溢出错误;解决办法:属性->链接器->系统->将堆栈大小和堆栈保留大小修改为100000000。

#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;

// Number of training samples to load; main passes it to parseMnistImage and
// parseMnistLabel as the read limit for the training files.
const int MnistTrainNumber = 6000;
// Number of test samples to load; main passes it to the same parse functions
// as the read limit for the t10k files.
const int MnistTestNumber = 1000;

//存储像素信息
struct Image
{
	cv::Mat pixs;
	Image()
	{
		pixs.create(Size(28, 28), CV_8U);
	}
};
struct MnistImage
{
	//检验值
	int magicNumber;
	//图片数量
	int number;
	//图片行数
	int rows;
	//图片列数
	int cols;
	//图片数组
	vector<Image> images;
};
// Contents of an MNIST label file: the two header fields plus the labels that
// were read. The header fields default to 0 so they hold a defined value even
// when parsing never gets as far as filling them in.
struct MnistLabel
{
	// Magic number from the file header, used as a sanity check
	int magicNumber = 0;
	// Number of labels declared in the header
	int number = 0;
	// Labels read from the file, one per sample
	std::vector<int> labels;
};
// Training set: the parsed training images together with their labels.
struct MnistTrainSet
{
	// Parsed contents of the training image file
	MnistImage trainImages;
	// Parsed contents of the training label file
	MnistLabel trainLabels;
};
// Test set: the parsed test images together with their labels.
// NOTE: the members keep the "train" prefix even though main fills them from
// the t10k (test) files.
struct MnistTestSet
{
	// Parsed contents of the test image file
	MnistImage trainImages;
	// Parsed contents of the test label file
	MnistLabel trainLabels;
};
// Reads `length` consecutive bytes starting at the current position of `file`
// and combines them into one integer, most significant byte first.
//
// Returns the decoded value (0 when `length` <= 0), or -1 if `file` is null or
// the stream ends / fails before `length` bytes could be read.
int readData(FILE *file, int length)
{
	if (!file)
	{
		return -1;
	}
	// Accumulate as unsigned so a leading byte >= 0x80 cannot trigger signed
	// overflow while shifting in the remaining bytes.
	unsigned int value = 0;
	for (int i = 0; i < length; i++)
	{
		const int byte = fgetc(file);
		if (byte == EOF)
		{
			// Report the short read instead of folding EOF (-1) into the value.
			return -1;
		}
		value = (value << 8) | static_cast<unsigned int>(byte);
	}
	return static_cast<int>(value);
}
//解析图片字节流文件
int parseMnistImage(const char *fileName, MnistImage &mnistImage, int imagesNumber)
{
	FILE *out = fopen(fileName, "rb");
	if (out == NULL) return -1;
	mnistImage.magicNumber = readData(out, 4);
	mnistImage.number = readData(out, 4);
	mnistImage.rows = readData(out, 4);
	mnistImage.cols = readData(out, 4);
	for (int k = 0; k < imagesNumber; k++)
	{
		Image image;
		for (int i = 0; i < 28; i++)
		{
			for (int j = 0; j < 28; j++)
			{
				int x = fgetc(out);
				image.pixs.at<uchar>(i, j) = x;
			}
		}
		mnistImage.images.push_back(image);
	}
	fclose(out);
	return mnistImage.magicNumber;
}
// Parses an MNIST label file.
//
// The header is read as two 4-byte integers (magic number, label count) and
// stored in `mnistLabel`. When the magic number is the expected 2049, up to
// `labelNumber` labels -- never more than the count declared in the header --
// are read, one byte each, and appended to `mnistLabel.labels`.
//
// Returns the magic number read from the file so the caller can compare it
// with 2049, or -1 if the file cannot be opened, the header's count is
// invalid, or the label data ends early (labels read before the short read
// stay in `mnistLabel.labels`).
int parseMnistLabel(const char *fileName, MnistLabel &mnistLabel, int labelNumber)
{
	const int kLabelMagic = 2049;

	// Binary mode: text mode on Windows would misreport EOF inside the data.
	FILE *file = fopen(fileName, "rb");
	if (!file) return -1;

	mnistLabel.magicNumber = readData(file, 4);
	mnistLabel.number = readData(file, 4);

	int result = mnistLabel.magicNumber;
	if (result == kLabelMagic)
	{
		if (mnistLabel.number < 0)
		{
			result = -1;
		}
		else
		{
			// Never read more labels than the file declares.
			const int count = labelNumber < mnistLabel.number ? labelNumber : mnistLabel.number;
			if (count > 0)
			{
				mnistLabel.labels.reserve(mnistLabel.labels.size() + count);
			}
			for (int i = 0; i < count; i++)
			{
				const int label = fgetc(file);
				if (label == EOF)
				{
					// Do not store EOF (-1) as if it were a label.
					result = -1;
					break;
				}
				mnistLabel.labels.push_back(label);
			}
		}
	}

	fclose(file);
	return result;
}
// Displays `mat` in a window titled "virtualizeData", then calls waitKey with
// a delay of 0 (its default), i.e. waits for a key press.
void virtualizeData(Mat &mat)
{
	const char *windowName = "virtualizeData";
	cv::imshow(windowName, mat);
	cv::waitKey(0);
}
int main()
{
	MnistTrainSet mnistTrainSet;
	MnistTestSet mnistTestSet;
	//magic number分别为2051,2049,2051,2049,与官方提供的检验值比对以确定解析程序是否有误
	cout << parseMnistImage("train-images.idx3-ubyte", mnistTrainSet.trainImages, MnistTrainNumber) << endl;
	cout << parseMnistLabel("train-labels.idx1-ubyte", mnistTrainSet.trainLabels, MnistTrainNumber) << endl;
	cout << parseMnistImage("t10k-images.idx3-ubyte", mnistTestSet.trainImages, MnistTestNumber) << endl;
	cout << parseMnistLabel("t10k-labels.idx1-ubyte", mnistTestSet.trainLabels, MnistTestNumber) << endl;
	//可视化训练集中第k张图片
	int k = 5;
	virtualizeData(mnistTrainSet.trainImages.images[k].pixs);
	return 0;
}


posted @ 2018-07-16 22:07  技术流选手  阅读(1090)  评论(4)  编辑  收藏  举报