Caffe::Snapshot的运行过程

Snapshot的存储

概述

Snapshot的存储格式有两种,分别是BINARYPROTO格式和hdf5格式。BINARYPROTO是一种二进制文件,并且可以通过修改shapshot_format来设置存储类型。该项的默认是BINARYPROTO不管哪种格式,运行的过程是类似的,都是从Solver<Dtype>::Snapshot()函数进入,首先调用Net网络的方法,再操作网络中的每一层,最后再操作每一层中blob,最后调用write函数写入输出。源码入口:

 1 void Solver<Dtype>::Snapshot() {
 2   CHECK(Caffe::root_solver());
 3   string model_filename;
 4   switch (param_.snapshot_format()) {
 5   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
 6     model_filename = SnapshotToBinaryProto();
 7     break;
 8   case caffe::SolverParameter_SnapshotFormat_HDF5:
 9     model_filename = SnapshotToHDF5();
10     break;
11   default:
12     LOG(FATAL) << "Unsupported snapshot format.";
13   }

 

BINARYPROTO格式

如果是BINARYPROTO的存储格式,就执行如下代码:

1 string Solver<Dtype>::SnapshotToBinaryProto() {
2   string model_filename = SnapshotFilename(".caffemodel");
3   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
4   NetParameter net_param;
5   net_->ToProto(&net_param, param_.snapshot_diff());
6   WriteProtoToBinaryFile(net_param, model_filename);
7   return model_filename;
8 }   

 

首先会执行SnapshotFilename(“.caffemodel”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToProto(),具体的代码如下:

 1 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 2   param->Clear();
 3   param->set_name(name_);
 4   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
 5     param->add_input(blob_names_[net_input_blob_indices_[i]]);
 6   }
 7   for (int i = 0; i < layers_.size(); ++i) {
 8     LayerParameter* layer_param = param->add_layer();
 9     layers_[i]->ToProto(layer_param, write_diff);
10   }
11 }  

 

获取到网络中的每层的名字等参数后,调用layers_[i]->ToProto()每一层的ToProto方法,接下来

1 void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
2   param->Clear();
3   param->CopyFrom(layer_param_);
4   param->clear_blobs();
5   for (int i = 0; i < blobs_.size(); ++i) {
6     blobs_[i]->ToProto(param->add_blobs(), write_diff);
7   }
8 } 

然后调用当前层下的所有blobToProto方法,即:

 1 void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
 2   proto->clear_shape();
 3   for (int i = 0; i < shape_.size(); ++i) {
 4     proto->mutable_shape()->add_dim(shape_[i]);
 5   }
 6   proto->clear_double_data();
 7   proto->clear_double_diff();
 8   const double* data_vec = cpu_data();
 9   for (int i = 0; i < count_; ++i) {
10     proto->add_double_data(data_vec[i]);
11   }
12   if (write_diff) {
13     const double* diff_vec = cpu_diff();
14     for (int i = 0; i < count_; ++i) {
15       proto->add_double_diff(diff_vec[i]);
16     }
17   }

 

在每一个blob中,会调用add_double_data()函数,把data添加到snapshot文件中,同时会判断是否当前blob参与diff的计算,如果需要当前blob需要diff参数,就调用add_double_diff()添加到snapshot文件中。

调用完所有的blobToProto()方法后,会执行WriteProtoToBinaryFile()把该文件写出即可。

1 void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
2   fstream output(filename, ios::out | ios::trunc | ios::binary);
3   CHECK(proto.SerializeToOstream(&output));
4 }

在该方法里调用FStreamoutput方法进行输出。

Hdf5格式

Hdf5格式的运行过程和BINARYPROTO格式的过程类似,首先会调用SnapshotToHDF5()函数,即:

1 string Solver<Dtype>::SnapshotToHDF5() {
2   string model_filename = SnapshotFilename(".caffemodel.h5");
3   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
4   net_->ToHDF5(model_filename, param_.snapshot_diff());
5   return model_filename;
6 }

首先会执行SnapshotFilename(“.caffemodel.h5”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToHDF5(),即:

 1 void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
 2   hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
 3       H5P_DEFAULT);
 4   hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
 5       H5P_DEFAULT);
 6     hid_t diff_hid = -1;
 7   if (write_diff) {
 8     diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
 9         H5P_DEFAULT);
10    }
11   for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
12     const LayerParameter& layer_param = layers_[layer_id]->layer_param();
13     string layer_name = layer_param.name();
14     hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
15     hid_t layer_diff_hid = -1;
16     if (write_diff) {
17       layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
18           H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);  
19  }
20     int num_params = layers_[layer_id]->blobs().size();
21     for (int param_id = 0; param_id < num_params; ++param_id) {
22       ostringstream dataset_name;
23       dataset_name << param_id;
24       const int net_param_id = param_id_vecs_[layer_id][param_id];
25       if (param_owners_[net_param_id] == -1) {
26         hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
27             *params_[net_param_id]);
28       }
29       if (write_diff) {
30         hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
31             *params_[net_param_id], true);
32       }
33 ...............
34 H5Fclose(file_hid);
35 }

该函数首先调用H5Fcreate()创建一个file文件,然后循环调用每一层,通过调用每一层的H5Gcreate2函数记录出该层的data_hid或者diff_hid(如果该层需要参与计算),然后进入每一层内部的blob,然后在当前blob内调用hdf5_save_nd_dataset()hdf5_save_nd_dataset()(如果当前blob需要参与计算diff),将data添加到hdf5格式的文件中,最后调用H5Fclose(file_hid)函数,输出该文件。

 

Snapshot的恢复

概述

想在已经训练好的网络上继续训练,那么需要调用Restore()方法从snapshot的文件中恢复成网络,从而缩短了训练时间。方法的入口是Solver<Dtype>::Restore(const char* state_file)函数,即:

1 void Solver<Dtype>::Restore(const char* state_file) {
2   CHECK(Caffe::root_solver());
3   string state_filename(state_file);
4   if (state_filename.size() >= 3 &&
5       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
6     RestoreSolverStateFromHDF5(state_filename);
7   } else {
8     RestoreSolverStateFromBinaryProto(state_filename);
9   }

该函数会解析snapshot文件是BINARYPROTO格式还是Hdf5格式,如果是BINARYPROTO格式的话就调用RestoreSolverStateFromBinaryProto()函数,如果格式Hdf5的格式,就执行RestoreSolverStateFromHDF5()

BINARYPROOTO格式

如果是BINARYPROTO格式,则执行下列代码:

 

 1 void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
 2     const string& state_file) {
 3   SolverState state;
 4   ReadProtoFromBinaryFile(state_file, &state);
 5   this->iter_ = state.iter();
 6   if (state.has_learned_net()) {
 7     NetParameter net_param;
 8     ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
 9     this->net_->CopyTrainedLayersFrom(net_param);
10   }
11   this->current_step_ = state.current_step();
12   CHECK_EQ(state.history_size(), history_.size())
13       << "Incorrect length of history blobs.";
14   for (int i = 0; i < history_.size(); ++i) {
15     history_[i]->FromProto(state.history(i));
16   }
17 }

 

该函数会大量调用googleprotobuf包内的函数,首先会通过ReadProtoFromBinaryFile()函数读取BINARYPROTO格式的文件来返回是否可以成功读取。然后判断该snapshot是否有曾经训练过的网络,如果有,则调用函数ReadNetParamsFromBinaryFileOrDie()读取出该Net网络,然后调用函数CopyTrainedLayersFrom(net_param)具体恢复该网络的每一层以及当前层内的所有blob,具体数据恢复的工作就是CopyTrainedLayersFrom()函数内部变量调用FromProto()函数来实现blob复制的。然后会通过函数current_step()来判断上次训练的位置(迭代到多少次),然后通过循环把训练过的data数据通过FromProto()完成数据的复制。

Hdf5格式

 1 void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
 2   hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
 3   CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
 4   this->iter_ = hdf5_load_int(file_hid, "iter");
 5   if (H5LTfind_dataset(file_hid, "learned_net")) {
 6     string learned_net = hdf5_load_string(file_hid, "learned_net");
 7     this->net_->CopyTrainedLayersFrom(learned_net);
 8   }
 9   this->current_step_ = hdf5_load_int(file_hid, "current_step");
10   hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
11   CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
12   int state_history_size = hdf5_get_num_links(history_hid);
13   CHECK_EQ(state_history_size, history_.size())
14       << "Incorrect length of history blobs.";
15   for (int i = 0; i < history_.size(); ++i) {
16     ostringstream oss;
17     oss << i;
18     hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
19                                 kMaxBlobAxes, history_[i].get());
20   }
21   H5Gclose(history_hid);
22   H5Fclose(file_hid);
23 }

该函数会识别hdf5格式存储的snapshot文件的file_hid编号,会判断是否存在之前训练过的网络,如果存在则执行CopyTrainedLayersFrom()函数,完成网络的每层以及每层内的blob的数据的恢复复制,然后或取上一次的训练位置(进行的迭代),并且调用函数hdf5_load_nd_dataset()具体把每次迭代的数据恢复复制,最后再调用H5Fclose()关闭。

 

 

 

 

 

posted @ 2017-08-10 16:59  liurio  阅读(1545)  评论(0编辑  收藏  举报