// Reading MNIST data in C++
// Loaders for the MNIST dataset stored in IDX files
// (e.g. train-labels.idx1-ubyte / train-images.idx3-ubyte).
//
// IDX header integers are 32-bit big-endian (most significant byte first):
//   label file: magic 2049, item count, then one unsigned byte per label
//   image file: magic 2051, image count, row count, column count, then
//               rows * cols unsigned bytes per image
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;

namespace {

// Magic numbers of unsigned-byte IDX files with 1 (labels) and 3 (images)
// dimensions; any other value means the file is not what the reader expects.
constexpr int kLabelFileMagic = 2049;
constexpr int kImageFileMagic = 2051;

// Reads one 32-bit big-endian integer from `in` into `value`.
// The value is assembled byte by byte, so the result does not depend on the
// host's endianness. Returns false (leaving `value` untouched) on a short read.
bool ReadBigEndianInt(istream &in, int &value) {
    unsigned char bytes[4] = {0, 0, 0, 0};
    if (!in.read(reinterpret_cast<char *>(bytes), sizeof(bytes))) {
        return false;
    }
    const uint32_t u = (static_cast<uint32_t>(bytes[0]) << 24) |
                       (static_cast<uint32_t>(bytes[1]) << 16) |
                       (static_cast<uint32_t>(bytes[2]) << 8) |
                       static_cast<uint32_t>(bytes[3]);
    value = static_cast<int>(u);
    return true;
}

}  // namespace

// Reverses the byte order of a 32-bit integer: 0x11223344 -> 0x44332211.
// Retained as a public utility; the readers below decode headers with an
// endian-independent helper instead. The shifts are done on an unsigned copy
// so a byte >= 0x80 moving into the top position cannot overflow a signed int
// (undefined behaviour before C++20).
int ReverseInt(int i) {
    const uint32_t u = static_cast<uint32_t>(i);
    const uint32_t swapped = ((u & 0x000000FFu) << 24) |
                             ((u & 0x0000FF00u) << 8) |
                             ((u & 0x00FF0000u) >> 8) |
                             ((u & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}

// Appends every label stored in the IDX1 file `filename` to `labels`,
// converting each unsigned byte to double.
//
// If the file cannot be opened, its header is short, or its magic number /
// count is invalid, a message is written to stderr and `labels` is left
// unchanged. If the file is truncated, only the labels actually present are
// appended (and the truncation is reported on stderr).
void read_Mnist_Label(const string &filename, vector<double> &labels) {
    ifstream file(filename, ios::binary);
    if (!file.is_open()) {
        cerr << "read_Mnist_Label: cannot open " << filename << '\n';
        return;
    }

    int magic_number = 0;
    int number_of_labels = 0;
    if (!ReadBigEndianInt(file, magic_number) ||
        !ReadBigEndianInt(file, number_of_labels)) {
        cerr << "read_Mnist_Label: truncated header in " << filename << '\n';
        return;
    }
    if (magic_number != kLabelFileMagic || number_of_labels < 0) {
        cerr << "read_Mnist_Label: unexpected header in " << filename
             << " (magic " << magic_number << ", count " << number_of_labels
             << ")\n";
        return;
    }

    for (int i = 0; i < number_of_labels; ++i) {
        char byte = 0;
        if (!file.get(byte)) {
            cerr << "read_Mnist_Label: " << filename << " ends after " << i
                 << " of " << number_of_labels << " labels\n";
            return;
        }
        // Go through unsigned char so values 128-255 are not sign-extended.
        labels.push_back(static_cast<double>(static_cast<unsigned char>(byte)));
    }
}

// Appends every image stored in the IDX3 file `filename` to `images`. Each
// image becomes one vector of n_rows * n_cols values, in the order the bytes
// appear in the file, each unsigned byte converted to double.
//
// If the file cannot be opened, its header is short, or its magic number /
// dimensions are invalid, a message is written to stderr and `images` is left
// unchanged. If the file is truncated, only the complete images actually
// present are appended (and the truncation is reported on stderr).
void read_Mnist_Images(const string &filename, vector<vector<double>> &images) {
    ifstream file(filename, ios::binary);
    if (!file.is_open()) {
        cerr << "read_Mnist_Images: cannot open " << filename << '\n';
        return;
    }

    int magic_number = 0;
    int number_of_images = 0;
    int n_rows = 0;
    int n_cols = 0;
    if (!ReadBigEndianInt(file, magic_number) ||
        !ReadBigEndianInt(file, number_of_images) ||
        !ReadBigEndianInt(file, n_rows) ||
        !ReadBigEndianInt(file, n_cols)) {
        cerr << "read_Mnist_Images: truncated header in " << filename << '\n';
        return;
    }
    if (magic_number != kImageFileMagic || number_of_images < 0 ||
        n_rows < 0 || n_cols < 0) {
        cerr << "read_Mnist_Images: unexpected header in " << filename
             << " (magic " << magic_number << ", count " << number_of_images
             << ", rows " << n_rows << ", cols " << n_cols << ")\n";
        return;
    }

    // Read a whole image per call instead of one byte per call; the raw
    // buffer is reused for every image.
    const size_t pixels_per_image =
        static_cast<size_t>(n_rows) * static_cast<size_t>(n_cols);
    vector<unsigned char> raw(pixels_per_image);
    for (int i = 0; i < number_of_images; ++i) {
        if (!file.read(reinterpret_cast<char *>(raw.data()),
                       static_cast<streamsize>(pixels_per_image))) {
            cerr << "read_Mnist_Images: " << filename << " ends after " << i
                 << " of " << number_of_images << " images\n";
            return;
        }
        images.emplace_back(raw.begin(), raw.end());
    }
}

// Example usage (disabled):
//
// int main01() {
//     vector<double> labels;
//     read_Mnist_Label("train-labels.idx1-ubyte", labels);
//     for (double label : labels) {
//         cout << label << " ";
//     }
//
//     vector<vector<double>> images;
//     read_Mnist_Images("train-images.idx3-ubyte", images);
//     cout << images.size();
//     for (const auto &image : images) {
//         for (double pixel : image) {
//             cout << pixel << " ";
//         }
//     }
//
//     return 0;
// }