Caffe: feature-extraction tool (extract_features) for the Caffe deep-learning framework

#include <cstdlib>  // for atoi
#include <cstring>  // for strcmp
#include <stdio.h>  // for snprintf
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
 15 
 16 using caffe::Blob;
 17 using caffe::Caffe;
 18 using caffe::Datum;
 19 using caffe::Net;
 20 using boost::shared_ptr;
 21 using std::string;
 22 namespace db = caffe::db;
 23 
 24 template<typename Dtype>
 25 int feature_extraction_pipeline(int argc, char** argv);
 26 
 27 int main(int argc, char** argv) {
 28   return feature_extraction_pipeline<float>(argc, argv);
 29 //  return feature_extraction_pipeline<double>(argc, argv);
 30 }
 31 
 32 template<typename Dtype>
 33 int feature_extraction_pipeline(int argc, char** argv) {
 34   ::google::InitGoogleLogging(argv[0]);
 35   const int num_required_args = 7;
 36   if (argc < num_required_args) {
 37     LOG(ERROR)<<
 38     "This program takes in a trained network and an input data layer, and then"
 39     " extract features of the input data produced by the net.\n"
 40     "Usage: extract_features  pretrained_net_param"
 41     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
 42     "  save_feature_dataset_name1[,name2,...]  num_mini_batches  db_type"
 43     "  [CPU/GPU] [DEVICE_ID=0]\n"
 44     "Note: you can extract multiple features in one pass by specifying"
 45     " multiple feature blob names and dataset names separated by ','."
 46     " The names cannot contain white space characters and the number of blobs"
 47     " and datasets must be equal.";
 48     return 1;
 49   }
 50   int arg_pos = num_required_args;
 51 
 52   arg_pos = num_required_args;
 53   if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
 54     LOG(ERROR)<< "Using GPU";
 55     uint device_id = 0;
 56     if (argc > arg_pos + 1) {
 57       device_id = atoi(argv[arg_pos + 1]);
 58       CHECK_GE(device_id, 0);
 59     }
 60     LOG(ERROR) << "Using Device_id=" << device_id;
 61     Caffe::SetDevice(device_id);
 62     Caffe::set_mode(Caffe::GPU);
 63   } else {
 64     LOG(ERROR) << "Using CPU";
 65     Caffe::set_mode(Caffe::CPU);
 66   }
 67 
 68   arg_pos = 0;  // the name of the executable
 69   std::string pretrained_binary_proto(argv[++arg_pos]);
 70 
 71   // Expected prototxt contains at least one data layer such as
 72   //  the layer data_layer_name and one feature blob such as the
 73   //  fc7 top blob to extract features.
 74   /*
 75    layers {
 76      name: "data_layer_name"
 77      type: DATA
 78      data_param {
 79        source: "/path/to/your/images/to/extract/feature/images_leveldb"
 80        mean_file: "/path/to/your/image_mean.binaryproto"
 81        batch_size: 128
 82        crop_size: 227
 83        mirror: false
 84      }
 85      top: "data_blob_name"
 86      top: "label_blob_name"
 87    }
 88    layers {
 89      name: "drop7"
 90      type: DROPOUT
 91      dropout_param {
 92        dropout_ratio: 0.5
 93      }
 94      bottom: "fc7"
 95      top: "fc7"
 96    }
 97    */
 98   std::string feature_extraction_proto(argv[++arg_pos]);
 99   shared_ptr<Net<Dtype> > feature_extraction_net(
100       new Net<Dtype>(feature_extraction_proto, caffe::TEST));
101   feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
102 
103   std::string extract_feature_blob_names(argv[++arg_pos]);
104   std::vector<std::string> blob_names;
105   boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
106 
107   std::string save_feature_dataset_names(argv[++arg_pos]);
108   std::vector<std::string> dataset_names;
109   boost::split(dataset_names, save_feature_dataset_names,
110                boost::is_any_of(","));
111   CHECK_EQ(blob_names.size(), dataset_names.size()) <<
112       " the number of blob names and dataset names must be equal";
113   size_t num_features = blob_names.size();
114 
115   for (size_t i = 0; i < num_features; i++) {
116     CHECK(feature_extraction_net->has_blob(blob_names[i]))
117         << "Unknown feature blob name " << blob_names[i]
118         << " in the network " << feature_extraction_proto;
119   }
120 
121   int num_mini_batches = atoi(argv[++arg_pos]);
122 
123   std::vector<shared_ptr<db::DB> > feature_dbs;
124   std::vector<shared_ptr<db::Transaction> > txns;
125   const char* db_type = argv[++arg_pos];
126   for (size_t i = 0; i < num_features; ++i) {
127     LOG(INFO)<< "Opening dataset " << dataset_names[i];
128     shared_ptr<db::DB> db(db::GetDB(db_type));
129     db->Open(dataset_names.at(i), db::NEW);
130     feature_dbs.push_back(db);
131     shared_ptr<db::Transaction> txn(db->NewTransaction());
132     txns.push_back(txn);
133   }
134 
135   LOG(ERROR)<< "Extacting Features";
136 
137   Datum datum;
138   const int kMaxKeyStrLength = 100;
139   char key_str[kMaxKeyStrLength];
140   std::vector<Blob<float>*> input_vec;
141   std::vector<int> image_indices(num_features, 0);
142   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
143     feature_extraction_net->Forward(input_vec);
144     for (int i = 0; i < num_features; ++i) {
145       const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
146           ->blob_by_name(blob_names[i]);
147       int batch_size = feature_blob->num();
148       int dim_features = feature_blob->count() / batch_size;
149       const Dtype* feature_blob_data;
150       for (int n = 0; n < batch_size; ++n) {
151         datum.set_height(feature_blob->height());
152         datum.set_width(feature_blob->width());
153         datum.set_channels(feature_blob->channels());
154         datum.clear_data();
155         datum.clear_float_data();
156         feature_blob_data = feature_blob->cpu_data() +
157             feature_blob->offset(n);
158         for (int d = 0; d < dim_features; ++d) {
159           datum.add_float_data(feature_blob_data[d]);
160         }
161         int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
162             image_indices[i]);
163         string out;
164         CHECK(datum.SerializeToString(&out));
165         txns.at(i)->Put(std::string(key_str, length), out);
166         ++image_indices[i];
167         if (image_indices[i] % 1000 == 0) {
168           txns.at(i)->Commit();
169           txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
170           LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
171               " query images for feature blob " << blob_names[i];
172         }
173       }  // for (int n = 0; n < batch_size; ++n)
174     }  // for (int i = 0; i < num_features; ++i)
175   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
176   // write the last batch
177   for (int i = 0; i < num_features; ++i) {
178     if (image_indices[i] % 1000 != 0) {
179       txns.at(i)->Commit();
180     }
181     LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
182         " query images for feature blob " << blob_names[i];
183     feature_dbs.at(i)->Close();
184   }
185 
186   LOG(ERROR)<< "Successfully extracted the features!";
187   return 0;
188 }
View Code

 

posted @ 2016-02-19 10:33  AHU-WangXiao  阅读(384)  评论(0编辑  收藏  举报