C++基于文件流和armadillo读取mnist

发现网上大把都是用python读取mnist的,用C++大都是用opencv读取的,但我不怎么用opencv,因此自己摸索了个使用文件流读取mnist的方法,armadillo仅作为储存矩阵的一种方式。

1. mnist文件

首先避坑,这些文件要解压。
没有肾么描述
官网截图可知,文件头很简单,只有若干个32位整数,MSB,像素和标签均是无符号字节(即unsigned char)可以先读取文件头,再读取剩下的部分。

2. 读取文件头

我觉得没什么必要啊,直接跳过不行吗
文件头都是32位,那就整四个unsigned char呗。

	FILE *File = fopen(fileName, "r");
	fseek(File, 0, 0);
	uchar a[4];
	fread(a, 4, 1, File);

这样a字符串就保存了一个整数。

	x = ((((a[0] * 256) + a[1]) * 256) + a[2]) * 256 + a[3];

然后就得到了呗。
看每个文件有多少文件头,就操作几次(并可以顺便与官方的magic number进行对比),剩下的就是文件的内容了。

3. 读取内容

这部分可以依照之前的方法,一次读取一个字符,再保存至矩阵当中。例如:

uchar a;
mat image(28, 28, fill::zeros); // 这是个矩阵!
for(int i = 0; i < 28; i++) //28行28列的图像懒得改了
	for(int j = 0; j < 28; j++)
	{
		fread(a, 1, 1, File);
		image(i, j) = double(a);
	}

这样就读取了一张图片。其余以此类推吧。

4. 完整代码

可以复制,可以修改,也可以用于商业和学术,但是请标注原作者(就是我)。
mnist.h

#ifndef MNIST_H  
#define MNIST_H
#include<iostream>
#include<fstream>
#include<armadillo>
#include<vector>
#include<cstdio>

#define uchar unsigned char

using namespace std;
using namespace arma;

//小端存储转换
int reverseInt(uchar *a);

//读取image数据集信息
mat read_mnist_image(const char *fileName);

//读取label数据集信息
mat read_mnist_label(const char* fileName);
#endif

mnist.cpp

//mnist.cpp
//作者:C艹
#include "mnist.h"

int reverseInt(uchar *a)
{
	return ((((a[0] * 256) + a[1]) * 256) + a[2]) * 256 + a[3];
}

mat read_mnist_image(const char *fileName)
{
	FILE *File = fopen(fileName, "r");
	fseek(File, 0, 0);
	mat image;
	uchar a[4];
	fread(a, 4, 1, File);
	int magic = reverseInt(a);
	if (magic != 2051) //magic number wrong
	{
		cout << magic;
		return mat(0, 0, fill::zeros);
	}
	fread(a, 4, 1, File);
	const int num_img = reverseInt(a);
	fread(a, 4, 1, File);
	const int num_row = reverseInt(a);
	fread(a, 4, 1, File);
	const int num_col = reverseInt(a);
	const int size_img = num_col * num_row;
	// 文件头读取完毕
	image.reshape(num_img, size_img);
	uchar img[784];
	for (int i = 0; i < num_img; i++)
	{
		fseek(File, i*784+16, SEEK_SET);
		fread(img, size_img, 1, File);
		for (int j = 0; j < size_img; j++)
		{
			image(i, j) = double(img[j])/256;
		}
	}
	fclose(File);
	return image;
}

mat read_mnist_label(const char *fileName)
{
	FILE* File = fopen(fileName, "r");
	fseek(File, 0, 0);
	uchar a[4];
	fread(a, 4, 1, File);
	int magic = reverseInt(a);
	if (magic != 2049) //magic number wrong
	{
		cout << magic;
		return mat(0, 0, fill::zeros);
	}
	fread(a, 4, 1, File);
	const int num_lab = reverseInt(a);
	// 文件头读取完毕
	mat label(num_lab, 10, fill::zeros);
	uchar lab[1];
	for (int i = 0; i < num_lab; i++)
	{
		fread(lab, 1, 1, File);
		label(i, int(lab[0])) = 1;
	}
	fclose(File);
	return label;
}


posted @ 2021-05-09 16:14  c艹用户  阅读(312)  评论(0编辑  收藏  举报