Caffe:深入分析(怎么训练)
main()
首先入口函数caffe.cpp
1 int main(int argc, char** argv) { 2 ...... 3 if (argc == 2) { 4 #ifdef WITH_PYTHON_LAYER 5 try { 6 #endif 7 return GetBrewFunction(caffe::string(argv[1]))(); //根据输入参数确定是train还是test,采用string到函数指针的映射实现,非常巧妙 8 #ifdef WITH_PYTHON_LAYER 9 } catch (bp::error_already_set) { 10 PyErr_Print(); 11 return 1; 12 } 13 #endif 14 } else { 15 gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); 16 } 17 }
在main函数中GetBrewFunction函数调用了通过工厂模式生成的由string到函数指针的map
1 typedef int (*BrewFunction)(); 2 typedef std::map<caffe::string, BrewFunction> BrewMap; 3 BrewMap g_brew_map;
在train、test、device_query、time函数后面都可以看到对这些函数的register,相当于这些函数指针已经在map中存在了
1 RegisterBrewFunction(train); 2 RegisterBrewFunction(test); 3 RegisterBrewFunction(device_query); 4 RegisterBrewFunction(time);
train()
接着是train过程
1 // Train / Finetune a model. 2 int train() { 3 ...... 4 caffe::SolverParameter solver_param; 5 caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param 6 ...... 7 shared_ptr<caffe::Solver<float> > 8 solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式 9 10 if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次 11 LOG(INFO) << "Resuming from " << FLAGS_snapshot; 12 solver->Restore(FLAGS_snapshot.c_str()); 13 } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型 14 CopyLayers(solver.get(), FLAGS_weights); 15 } 16 17 if (gpus.size() > 1) { 18 caffe::P2PSync<float> sync(solver, NULL, solver->param()); 19 sync.Run(gpus); 20 } else { 21 LOG(INFO) << "Starting Optimization"; 22 solver->Solve();//开始训练网络 23 } 24 LOG(INFO) << "Optimization Done."; 25 return 0; 26 }
Solver()
看CreateSolver函数是如何构建solver和net的,CreateSolver定义在solver_factory.hpp中,首先需要知道的是solver是一个基类,继承自它的类有SGD等,下面的实现就可以根据param的type构造一个指向特定solver的指针,比如SGD。
1 static Solver<Dtype>* CreateSolver(const SolverParameter& param) { 2 const string& type = param.type(); 3 CreatorRegistry& registry = Registry(); 4 CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type 5 << " (known types: " << SolverTypeListString() << ")"; 6 return registry[type](param); 7 }
关键之处在于上面代码最后一行语句,它的作用是根据配置文件创建对应的Solver对象(默认为SGDSolver子类对象)。此处工厂模式和一个关键的宏REGISTER_SOLVER_CLASS(SGD)发挥了重要作用。
1 #define REGISTER_SOLVER_CLASS(type) 2 template <typename Dtype> 3 Solver<Dtype>* Creator_##type##Solver( 4 const SolverParameter& param) 5 { 6 return new type##Solver<Dtype>(param); 7 } 8 REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) 9 }
这样一个SGDSolver对象就调用其构造函数被构造出来了。
1 explicit SGDSolver(const SolverParameter& param) 2 : Solver<Dtype>(param) { PreSolve(); }
同时,Solver这个基类也被构造出来了,在solver.hpp里
1 explicit Solver(const SolverParameter& param, 2 const Solver* root_solver = NULL);
Solver构造函数又会调用Init进行训练网络和测试网络的初始化,Init函数没有被声明为虚函数,不能被覆写,也就是说所有的solver都调用这个函数进行初始化。
1 template <typename Dtype> 2 void Solver<Dtype>::Init(const SolverParameter& param) { 3 ...... 4 // Scaffolding code 5 InitTrainNet();//初始化训练网络 6 if (Caffe::root_solver()) { 7 InitTestNets();//初始化测试网络 8 LOG(INFO) << "Solver scaffolding done."; 9 } 10 iter_ = 0;//迭代次数设为0 11 current_step_ = 0; 12 }
InitTrainNet()
接下来看训练网络初始化函数InitTrainNet,具体的内容见Net的网络层的构建(源码分析)
caffe是如何来solve的:在成员函数Solve()内部,
1 template <typename Dtype> 2 void Solver<Dtype>::Solve(const char* resume_file) { 3 ...... 4 // For a network that is trained by the solver, no bottom or top vecs 5 // should be given, and we will just provide dummy vecs. 6 int start_iter = iter_; 7 //开始迭代 8 Step(param_.max_iter() - iter_); 9 ...... 10 }
Step()
下面我们看一下Solver::Step()函数内部实现情况,具体的一次迭代过程。见Caffe参数交换源码分析
这就是整个网络的训练过程。
当神已无能为力,那便是魔渡众生