Reading the MNIST Handwritten Digit Dataset in C++
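It is not that I love this life of wind and dust; it seems I was misled by a fate from a former life. Flowers fall and bloom in their own time, always at the command of the Lord of the East.
Go I must in the end; how could I stay? If only I could wear my hair full of mountain flowers, ask not where I am bound. -- Yan Rui, Song dynasty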
The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/. It consists of four files:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 images)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 images)
- Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)
Taking train-images-idx3-ubyte as an example, the file is laid out as follows:

| offset | type | value | description |
|---|---|---|---|
| 0000 | 32-bit integer | 0x00000803 (2051) | magic number |
| 0004 | 32-bit integer | 60000 | number of images |
| 0008 | 32-bit integer | 28 | number of rows |
| 0012 | 32-bit integer | 28 | number of columns |
| 0016 | unsigned byte | ?? | pixel |
| 0017 | unsigned byte | ?? | pixel |
| ... | ... | ... | ... |
| xxxx | unsigned byte | ?? | pixel |
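The offsets in the table translate directly into arithmetic. The sketch below illustrates this and is not part of the original code. The helper names `pixel_offset` and `idx3_file_size` are hypothetical. It assumes the 16-byte header shown above followed by the pixels stored row by row. The same arithmetic also accounts for the uncompressed sizes listed earlier.

```cpp
#include <cstdint>

// Layout of an IDX3 image file as in the table: a 16-byte header
// (four 32-bit integers), then the images one after another, each stored
// row by row (row-major), one unsigned byte per pixel.
constexpr std::uint64_t kIdx3HeaderBytes = 16;

// Byte offset of pixel (row r, column c) of image k (hypothetical helper).
constexpr std::uint64_t pixel_offset(std::uint64_t k, std::uint64_t r, std::uint64_t c,
                                     std::uint64_t rows, std::uint64_t cols) {
    return kIdx3HeaderBytes + k * rows * cols + r * cols + c;
}

// Total size of an uncompressed image file holding `count` images.
constexpr std::uint64_t idx3_file_size(std::uint64_t count,
                                       std::uint64_t rows, std::uint64_t cols) {
    return kIdx3HeaderBytes + count * rows * cols;
}

// 60,000 training images: 16 + 60000 * 784 = 47,040,016 bytes (about 47 MB).
static_assert(idx3_file_size(60000, 28, 28) == 47040016, "training set size");
// 10,000 test images: 16 + 10000 * 784 = 7,840,016 bytes (about 7.8 MB).
static_assert(idx3_file_size(10000, 28, 28) == 7840016, "test set size");
```

For example, `pixel_offset(0, 0, 0, 28, 28)` is 16 and `pixel_offset(0, 0, 1, 28, 28)` is 17. These match the first two pixel rows of the table.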
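First, read the header fields: the magic number, the number of images, and the number of rows and columns. These are 32-bit big-endian (MSB first) integers. On a little-endian machine such as x86, they therefore cannot be copied into an int as-is; the bytes have to be reordered first.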
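```cpp
#include <cstdint>
#include <fstream>
#include <Eigen/Dense>
using namespace std;
using Eigen::MatrixXd;

// filename: path to the uncompressed train-images-idx3-ubyte file
ifstream f(filename, ios::in | ios::binary);  // open in binary mode

unsigned char p[4];
MatrixXd info(1, 4);  // magic number, number of images, rows, columns

for (int i = 0; i < 4; i++) {
    f.read(reinterpret_cast<char*>(p), 4);
    // Each field is big-endian (MSB first). Assembling the value from
    // unsigned bytes works on any host byte order, so no byte swap or
    // memcpy is needed. (Shifting plain char, which is usually signed,
    // would sign-extend bytes >= 0x80.)
    uint32_t num = (uint32_t(p[0]) << 24) | (uint32_t(p[1]) << 16) |
                   (uint32_t(p[2]) << 8) | uint32_t(p[3]);
    info(0, i) = num;  // info has a single row, so the row index is 0
}
```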
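The magic number itself carries information. According to the MNIST page, its first two bytes are always 0. The third byte gives the element type (0x08 means unsigned byte), and the fourth gives the number of dimensions. This is why the image file uses 0x00000803 and the label file uses 0x00000801.

The following is a minimal sketch of decoding and checking it. `be_u32`, `IdxMagic` and `decode_magic` are hypothetical names. The raw-byte interface is an assumption chosen so the example depends on neither Eigen nor a file on disk.

```cpp
#include <cstdint>
#include <stdexcept>

// Decode a big-endian (MSB-first) 32-bit integer from 4 raw bytes.
// Building the value from unsigned bytes is correct on any host byte order.
std::uint32_t be_u32(const unsigned char* p) {
    return (std::uint32_t(p[0]) << 24) | (std::uint32_t(p[1]) << 16) |
           (std::uint32_t(p[2]) << 8) | std::uint32_t(p[3]);
}

// Fields of an IDX magic number (hypothetical struct).
struct IdxMagic {
    std::uint32_t type_code;  // third byte: 0x08 = unsigned byte
    std::uint32_t dims;       // fourth byte: 3 for image files, 1 for label files
};

// Split the magic number; the top two bytes must be zero in an IDX file.
IdxMagic decode_magic(std::uint32_t magic) {
    if ((magic >> 16) != 0)
        throw std::runtime_error("not an IDX file: top two bytes of magic must be 0");
    return IdxMagic{(magic >> 8) & 0xFFu, magic & 0xFFu};
}
```

Reading the first 8 bytes of train-images-idx3-ubyte into a buffer `hdr` should give `be_u32(hdr) == 2051` and `be_u32(hdr + 4) == 60000`. Checking `decode_magic(...).dims` then catches a label file passed by mistake.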
Each image occupies 28 × 28 = 784 bytes. Read 784 bytes at a time and store each image as a matrix in ws (a vector<MatrixXd>):
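```cpp
#include <vector>

vector<MatrixXd> ws;  // one 28x28 matrix per image
MatrixXd im(28, 28);
char pix[784];        // 28 * 28 = 784 bytes per image

// Test the result of read() instead of checking eof() first: eof() is only
// set after a read has already failed, so `while (!f.eof())` runs one extra
// iteration and pushes the last image a second time.
while (f.read(pix, 784)) {
    for (int i = 0; i < im.rows(); i++) {
        for (int j = 0; j < im.cols(); j++) {
            // Pixels are stored row by row. Cast through unsigned char so
            // values 128..255 are not read as negative numbers.
            im(i, j) = static_cast<unsigned char>(pix[i * 28 + j]);
        }
    }
    ws.push_back(im);
}
```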
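Looping until end-of-file assumes the file ends exactly on an image boundary. An alternative is to read exactly the number of images stated in the header and reject truncated files. The sketch below does this, assuming the whole uncompressed file has already been loaded into memory. It uses plain `std::vector<std::uint8_t>` instead of Eigen matrices. `Image`, `be32` and `parse_idx3_images` are hypothetical names, not the original author's code.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Image = std::vector<std::uint8_t>;  // rows * cols pixels, row-major

// Big-endian (MSB-first) 32-bit integer from 4 bytes.
static std::uint32_t be32(const std::uint8_t* p) {
    return (std::uint32_t(p[0]) << 24) | (std::uint32_t(p[1]) << 16) |
           (std::uint32_t(p[2]) << 8) | std::uint32_t(p[3]);
}

// Parse an in-memory IDX3 image file (e.g. train-images-idx3-ubyte),
// reading exactly the number of images declared in the header.
std::vector<Image> parse_idx3_images(const std::vector<std::uint8_t>& buf) {
    if (buf.size() < 16) throw std::runtime_error("file too short for IDX3 header");
    if (be32(buf.data()) != 2051) throw std::runtime_error("bad magic number");
    const std::size_t count = be32(buf.data() + 4);
    const std::size_t rows = be32(buf.data() + 8);
    const std::size_t cols = be32(buf.data() + 12);
    const std::size_t pixels = rows * cols;
    const std::size_t available = buf.size() - 16;
    if (pixels == 0) throw std::runtime_error("zero image dimensions");
    if (count > available / pixels) throw std::runtime_error("truncated image data");

    std::vector<Image> images;
    images.reserve(count);
    for (std::size_t k = 0; k < count; ++k) {
        // Image k starts right after the header plus k full images.
        auto first = buf.begin() + static_cast<std::ptrdiff_t>(16 + k * pixels);
        images.emplace_back(first, first + static_cast<std::ptrdiff_t>(pixels));
    }
    return images;
}
```

One way to load such a buffer is `std::vector<std::uint8_t> buf(std::istreambuf_iterator<char>(f), {});` on a stream opened with `ios::binary`. This requires `<iterator>`. Whole-file loading costs about 47 MB for the training images, which is usually acceptable.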
Use the CImg library to display an image and check that it was read correctly:
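```cpp
#include <iostream>
#include "CImg.h"
using namespace cimg_library;

CImg<unsigned char> img(28, 28, 1, 1);  // width, height, depth, channels
img.fill(0);

// im holds the last image read above. CImg indexes pixels as (x, y),
// i.e. (column, row), while Eigen indexes them as (row, column). With a
// depth of 1 the only valid z index is 0, so use img(x, y) directly.
for (int i = 0; i < img.height(); i++) {
    for (int j = 0; j < img.width(); j++) {
        img(j, i) = static_cast<unsigned char>(im(i, j));
    }
}

cout << im << endl;                       // print the pixel values
cout << "ws.size:" << ws.size() << endl;  // expected 60000 for the training set
img.display("My first CImg code");
```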
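If CImg is not available, a dependency-free way to eyeball an image is to print it as text. The sketch below offers this as an alternative check, not the author's method. The function name `to_ascii` and the threshold of 128 are arbitrary choices. It assumes a row-major buffer of pixel values where 0 is background and larger values are the digit's stroke, as in MNIST.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Render a row-major grayscale image as text, one line per row:
// '#' for strong stroke pixels (>= 128), '.' for faint ones (1..127),
// and ' ' for background (0).
std::string to_ascii(const std::vector<std::uint8_t>& pix,
                     std::size_t rows, std::size_t cols) {
    std::string out;
    out.reserve(rows * (cols + 1));
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            const std::uint8_t v = pix[r * cols + c];
            out += v >= 128 ? '#' : (v > 0 ? '.' : ' ');
        }
        out += '\n';
    }
    return out;
}
```

Printing `to_ascii(pixels, 28, 28)` for one image should show a recognizable digit in the console. Comparing it with the corresponding label is a quick sanity check.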
The label files are laid out as follows:

| offset | type | value | description |
|---|---|---|---|
| 0000 | 32-bit integer | 0x00000801 (2049) | magic number (MSB first) |
| 0004 | 32-bit integer | 60000 | number of items |
| 0008 | unsigned byte | ?? | label |
| 0009 | unsigned byte | ?? | label |
| ... | ... | ... | ... |
| xxxx | unsigned byte | ?? | label |
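The code for the labels is analogous: read the 8-byte header, then one byte per label. It is not repeated here.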
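For completeness, a minimal label reader might look like the following. It is a sketch, not the original author's code. It assumes the whole uncompressed label file has been read into memory as bytes; `be32` and `parse_idx1_labels` are hypothetical names. It checks the magic number 2049 and the item count from the header, then returns one byte per label, each expected to be a digit from 0 to 9.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Big-endian (MSB-first) 32-bit integer from 4 bytes.
static std::uint32_t be32(const std::uint8_t* p) {
    return (std::uint32_t(p[0]) << 24) | (std::uint32_t(p[1]) << 16) |
           (std::uint32_t(p[2]) << 8) | std::uint32_t(p[3]);
}

// Parse an in-memory IDX1 label file (e.g. train-labels-idx1-ubyte):
// magic number 2049, item count, then one unsigned byte per label.
std::vector<std::uint8_t> parse_idx1_labels(const std::vector<std::uint8_t>& buf) {
    if (buf.size() < 8) throw std::runtime_error("file too short for IDX1 header");
    if (be32(buf.data()) != 2049) throw std::runtime_error("bad magic number");
    const std::size_t count = be32(buf.data() + 4);
    if (count > buf.size() - 8) throw std::runtime_error("truncated label data");

    auto first = buf.begin() + 8;
    std::vector<std::uint8_t> labels(first, first + static_cast<std::ptrdiff_t>(count));
    for (std::uint8_t v : labels)
        if (v > 9) throw std::runtime_error("label out of range 0..9");
    return labels;
}
```

For the training set, the returned vector should have 60,000 entries. Entry k should be the label of image k.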