convert_mnist_data.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | // This script converts the MNIST dataset to a lmdb (default) or // leveldb (--backend=leveldb) format used by caffe to load data. // Usage: // convert_mnist_data [FLAGS] input_image_file input_label_file // output_db_file // The MNIST dataset could be downloaded at // http://yann.lecun.com/exdb/mnist/ #include <gflags/gflags.h> #include <glog/logging.h> #include <google/protobuf/text_format.h> #if defined(USE_LEVELDB) && defined(USE_LMDB) #include <leveldb/db.h> #include <leveldb/write_batch.h> #include <lmdb.h> #endif #include <stdint.h> #include <sys/stat.h> #include <fstream> // NOLINT(readability/streams) #include <string> #include "boost/scoped_ptr.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/db.hpp" #include "caffe/util/format.hpp" #if defined(USE_LEVELDB) && defined(USE_LMDB) using namespace caffe; // NOLINT(build/namespaces) using boost::scoped_ptr; using std::string; DEFINE_string(backend, "lmdb" , "The backend for storing the result" ); uint32_t swap_endian(uint32_t val) { val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); return (val << 16) | (val >> 16); } void convert_dataset( const char * image_filename, const char * label_filename, const char * db_path, const string& db_backend) { // Open files std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); CHECK(image_file) << "Unable to open file " << image_filename; CHECK(label_file) << "Unable to open file " << 
label_filename; // Read the magic and the meta data uint32_t magic; uint32_t num_items; uint32_t num_labels; uint32_t rows; uint32_t cols; image_file.read( reinterpret_cast < char *>(&magic), 4); magic = swap_endian(magic); CHECK_EQ(magic, 2051) << "Incorrect image file magic." ; label_file.read( reinterpret_cast < char *>(&magic), 4); magic = swap_endian(magic); CHECK_EQ(magic, 2049) << "Incorrect label file magic." ; image_file.read( reinterpret_cast < char *>(&num_items), 4); num_items = swap_endian(num_items); label_file.read( reinterpret_cast < char *>(&num_labels), 4); num_labels = swap_endian(num_labels); CHECK_EQ(num_items, num_labels); image_file.read( reinterpret_cast < char *>(&rows), 4); rows = swap_endian(rows); image_file.read( reinterpret_cast < char *>(&cols), 4); cols = swap_endian(cols); scoped_ptr<db::DB> db(db::GetDB(db_backend)); db->Open(db_path, db::NEW); scoped_ptr<db::Transaction> txn(db->NewTransaction()); // Storing to db char label; char * pixels = new char [rows * cols]; int count = 0; string value; Datum datum; datum.set_channels(1); datum.set_height(rows); datum.set_width(cols); LOG(INFO) << "A total of " << num_items << " items." ; LOG(INFO) << "Rows: " << rows << " Cols: " << cols; for ( int item_id = 0; item_id < num_items; ++item_id) { image_file.read(pixels, rows * cols); label_file.read(&label, 1); datum.set_data(pixels, rows*cols); datum.set_label(label); string key_str = caffe::format_int(item_id, 8); datum.SerializeToString(&value); txn->Put(key_str, value); if (++count % 1000 == 0) { txn->Commit(); } } // write the last batch if (count % 1000 != 0) { txn->Commit(); } LOG(INFO) << "Processed " << count << " files." 
; delete [] pixels; db->Close(); } int main( int argc, char ** argv) { #ifndef GFLAGS_GFLAGS_H_ namespace gflags = google; #endif FLAGS_alsologtostderr = 1; gflags::SetUsageMessage( "This script converts the MNIST dataset to\n" "the lmdb/leveldb format used by Caffe to load data.\n" "Usage:\n" " convert_mnist_data [FLAGS] input_image_file input_label_file " "output_db_file\n" "The MNIST dataset could be downloaded at\n" " http://yann.lecun.com/exdb/mnist/\n" "You should gunzip them after downloading," "or directly use data/mnist/get_mnist.sh\n" ); gflags::ParseCommandLineFlags(&argc, &argv, true ); const string& db_backend = FLAGS_backend; if (argc != 4) { gflags::ShowUsageWithFlagsRestrict(argv[0], "examples/mnist/convert_mnist_data" ); } else { google::InitGoogleLogging(argv[0]); convert_dataset(argv[1], argv[2], argv[3], db_backend); } return 0; } #else int main( int argc, char ** argv) { LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB." ; } #endif // USE_LEVELDB and USE_LMDB |
代码中DEFINE_string(backend, "lmdb", "The backend for storing the result")这一句使用了gflags工具(Google开源的命令行参数解析库)。它的作用是定义一个string类型的命令行参数backend,并将其默认值设为“lmdb”:运行时如果没有指定该参数(例如--backend=leveldb),就使用这个默认值。程序中通过FLAGS_backend读取该参数的值。除DEFINE_string外,gflags还提供DEFINE_bool、DEFINE_int32、DEFINE_int64、DEFINE_uint64、DEFINE_double等宏来定义其他类型的参数。
代码中 std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); 这一句里,ifstream是输入文件流类,image_file是该类的一个对象;std::ios::in | std::ios::binary表示以二进制方式打开文件用于读取,相当于C语言fopen中的“rb”模式。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步