caffe的Solver介绍 以及 max_iter迭代的具体代码实现
2016-09-22 10:29
447 查看
Original url:
http://alanse7en.github.io/caffedai-ma-jie-xi-4/
本文将主要分为四部分的内容:
caffe.cpp中的train函数中通过上面的代码定义了一个指向
Create solver
下面我们就来具体看一下
首先需要注意的是这个类的构造函数是private的,也就是用我们没有办法去构造一个这个类型的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
我们首先从
上面的代码中,
Register Solver
下面我们具体来看一下Solver的register的过程:
在sgd_solver.cpp(SGD Solver对应的cpp文件)末尾有上面第24行的代码,使用了
Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),我们可以通过对
caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT
在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:
其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:
FLAGS_sigint_effect和FLAGS_sighup_effect是通过gflags定义和解析的两个Command Line Interface的输入参数,分别对应遇到sigint和sighup信号的处理方式,如果用户不设定(大部分时候我自己就没设定),sigint的默认值为”stop”,sighup的默认值为”snapshot”。
= &handle_signal来设置,当有遇到系统信号时,调用
在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过
总结起来,我们通过定义一个
下面继续分析具体的迭代过程发生的
每一组网络中的参数的更新都是在不同类型的Solver自己实现的
下面我们继续具体分析一下
至此
http://alanse7en.github.io/caffedai-ma-jie-xi-4/
本文将主要分为四部分的内容:
Solver的初始化(Register宏和构造函数)
SIGINT和
SIGHUP信号的处理
Solver::Solve()具体实现
SGDSolver::ApplyUpdate具体实现
Solver的初始化(Register宏和构造函数)
shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
caffe.cpp中的train函数中通过上面的代码定义了一个指向
Solver<float>的shared_ptr。其中主要是通过调用
SolverRegistry这个类的静态成员函数
CreateSolver得到一个指向
Solver的指针来构造shared_ptr类型的
solver。而且由于C++多态的特性,尽管
solver是一个指向基类
Solver类型的指针,通过
solver这个智能指针来调用各个成员函数会调用到各个子类(
SGDSolver等)的函数。具体的过程如下面的流程图所示:
Create solver
下面我们就来具体看一下
SolverRegistry这个类的代码,以便理解是如何通过同一个函数得到不同类型的Solver:
1 class SolverRegistry { 2 public: 3 typedef Solver<Dtype>* (*Creator)(const SolverParameter&); 4 typedef std::map<string, Creator> CreatorRegistry; 5 static CreatorRegistry& Registry() { 6 static CreatorRegistry* g_registry_ = new CreatorRegistry(); 7 return *g_registry_; 8 } 9 static void AddCreator(const string& type, Creator creator) { 10 CreatorRegistry& registry = Registry(); 11 CHECK_EQ(registry.count(type), 0) 12 << "Solver type " << type << " already registered."; 13 registry[type] = creator; 14 } 15 static Solver<Dtype>* CreateSolver(const SolverParameter& param) { 16 const string& type = param.type(); 17 CreatorRegistry& registry = Registry(); 18 CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type 19 << " (known types: " << SolverTypeListString() << ")"; 20 return registry[type](param); 21 } 22 static vector<string> SolverTypeList() { 23 CreatorRegistry& registry = Registry(); 24 vector<string> solver_types; 25 for (typename CreatorRegistry::iterator iter = registry.begin(); 26 iter != registry.end(); ++iter) { 27 solver_types.push_back(iter->first); 28 } 29 return solver_types; 30 } 31 private: 32 SolverRegistry() {} 33 static string SolverTypeListString() { 34 vector<string> solver_types = SolverTypeList(); 35 string solver_types_str; 36 for (vector<string>::iterator iter = solver_types.begin(); 37 iter != solver_types.end(); ++iter) { 38 if (iter != solver_types.begin()) { 39 solver_types_str += ", "; 40 } 41 solver_types_str += *iter; 42 } 43 return solver_types_str; 44 } 45 };
首先需要注意的是这个类的构造函数是private的,也就是用我们没有办法去构造一个这个类型的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
我们首先从
CreateSolver函数(第15行)入手,这个函数先定义了string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等),然后定义了一个key类型为string,value类型为
Creator的map:registry,其中
Creator是一个函数指针类型,指向的函数的参数为
SolverParameter类型,返回类型为
Solver<Dtype>*(见第2行和第3行)。如果是一个已经register过的Solver类型,那么
registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的
Solver<Dtype>*返回。
上面的代码中,
Registry这个函数(第5行)中定义了一个static的变量g_registry,这个变量是一个指向
CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,是存储在这个map中的。事实上各个Solver的register的过程正是往g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。Register的过程如流程图所示:
Register Solver
下面我们具体来看一下Solver的register的过程:
1 template <typename Dtype> 2 class SolverRegisterer { 3 public: 4 SolverRegisterer(const string& type, 5 Solver<Dtype>* (*creator)(const SolverParameter&)) { 6 // LOG(INFO) << "Registering solver type: " << type; 7 SolverRegistry<Dtype>::AddCreator(type, creator); 8 } 9 }; 10 #define REGISTER_SOLVER_CREATOR(type, creator) \ 11 static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \ 12 static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \ 13 14 #define REGISTER_SOLVER_CLASS(type) \ 15 template <typename Dtype> \ 16 Solver<Dtype>* Creator_##type##Solver( \ 17 const SolverParameter& param) \ 18 { \ 19 return new type##Solver<Dtype>(param); \ 20 } \ 21 REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) 22 } 23 // register SGD Solver 24 REGISTER_SOLVER_CLASS(SGD);
在sgd_solver.cpp(SGD Solver对应的cpp文件)末尾有上面第24行的代码,使用了
REGISTER_SOLVER_CLASS这个宏,这个宏会定义一个名为
Creator_SGDSolver的函数,这个函数即为
Creator类型的指针指向的函数,在这个函数中调用了
SGDSolver的构造函数,并将构造的这个变量得到的指针返回,这也就是Creator类型函数的作用:构造一个对应类型的Solver对象,将其指针返回。然后在这个宏里又调用了
REGISTER_SOLVER_CREATOR这个宏,这里分别定义了
SolverRegisterer这个模板类的float和double类型的static变量,这会去调用各自的构造函数,而在
SolverRegisterer的构造函数中调用了之前提到的
SolverRegistry类的
AddCreator函数,这个函数就是将刚才定义的
Creator_SGDSolver这个函数的指针存到g_registry指向的map里面。类似地,所有的Solver对应的cpp文件的末尾都调用了这个宏来完成注册,在所有的Solver都注册之后,我们就可以通过之前描述的方式,通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。Register和Create对应的流程图如下所示:
SIGINT
和SIGHUP
信号的处理
Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),我们可以通过对sigint_effect和
sighup_effect来设置遇到系统信号的时候希望进行的处理方式:
caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT
在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:
1 caffe::SolverAction::Enum GetRequestedAction( 2 const std::string& flag_value) { 3 if (flag_value == "stop") { 4 return caffe::SolverAction::STOP; 5 } 6 if (flag_value == "snapshot") { 7 return caffe::SolverAction::SNAPSHOT; 8 } 9 if (flag_value == "none") { 10 return caffe::SolverAction::NONE; 11 } 12 LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; 13 } 14 // SolverAction::Enum的定义 15 namespace SolverAction { 16 enum Enum { 17 NONE = 0, // Take no special action. 18 STOP = 1, // Stop training. snapshot_after_train controls whether a 19 // snapshot is created. 20 SNAPSHOT = 2 // Take a snapshot, and keep training. 21 }; 22 }
其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:
1 caffe::SignalHandler signal_handler( 2 GetRequestedAction(FLAGS_sigint_effect), 3 GetRequestedAction(FLAGS_sighup_effect)); 4 5 solver->SetActionFunction(signal_handler.GetActionFunction());
FLAGS_sigint_effect和FLAGS_sighup_effect是通过gflags定义和解析的两个Command Line Interface的输入参数,分别对应遇到sigint和sighup信号的处理方式,如果用户不设定(大部分时候我自己就没设定),sigint的默认值为”stop”,sighup的默认值为”snapshot”。
GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个
SignalHandler类型的对象signal_handler。我们可以看到这部分代码都依赖于
SignalHandler这个类的接口,我们先来看看这个类都做了些什么:
1 // header file 2 class SignalHandler { 3 public: 4 // Contructor. Specify what action to take when a signal is received. 5 SignalHandler(SolverAction::Enum SIGINT_action, 6 SolverAction::Enum SIGHUP_action); 7 ~SignalHandler(); 8 ActionCallback GetActionFunction(); 9 private: 10 SolverAction::Enum CheckForSignals() const; 11 SolverAction::Enum SIGINT_action_; 12 SolverAction::Enum SIGHUP_action_; 13 }; 14 // source file 15 SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action, 16 SolverAction::Enum SIGHUP_action): 17 SIGINT_action_(SIGINT_action), 18 SIGHUP_action_(SIGHUP_action) { 19 HookupHandler(); 20 } 21 void HookupHandler() { 22 if (already_hooked_up) { 23 LOG(FATAL) << "Tried to hookup signal handlers more than once."; 24 } 25 already_hooked_up = true; 26 struct sigaction sa; 27 sa.sa_handler = &handle_signal; 28 // ... 29 } 30 static volatile sig_atomic_t got_sigint = false; 31 static volatile sig_atomic_t got_sighup = false; 32 void handle_signal(int signal) { 33 switch (signal) { 34 case SIGHUP: 35 got_sighup = true; 36 break; 37 case SIGINT: 38 got_sigint = true; 39 break; 40 } 41 } 42 ActionCallback SignalHandler::GetActionFunction() { 43 return boost::bind(&SignalHandler::CheckForSignals, this); 44 } 45 SolverAction::Enum SignalHandler::CheckForSignals() const { 46 if (GotSIGHUP()) { 47 return SIGHUP_action_; 48 } 49 if (GotSIGINT()) { 50 return SIGINT_action_; 51 } 52 return SolverAction::NONE; 53 } 54 bool GotSIGINT() { 55 bool result = got_sigint; 56 got_sigint = false; 57 return result; 58 } 59 bool GotSIGHUP() { 60 bool result = got_sighup; 61 got_sighup = false; 62 return result; 63 } 64 // ActionCallback的含义 65 typedef boost::function<SolverAction::Enum()> ActionCallback;
SignalHandler这个类有两个数据成员,都是
SolverAction::Enum类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了
HookupHandler函数,这个函数的主要作用是定义了一个
sigaction类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler
= &handle_signal来设置,当有遇到系统信号时,调用
handle_signal函数来处理,而我们可以看到这个函数的处理很简单,就是判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。
在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过
SetActionFunction来设置了如何处理系统信号。这个函数的输入为signal_handler的
GetActionFunction的返回值,根据上面的代码我们可以看到,
GetActionFunction会返回signal_handler这个对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在
Solver的
SetActionFunction函数中只是简单的把
Solver的一个成员action_request_function_赋值为输入参数的值,以当前的例子来说就是,solver对象的action_request_function_指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。
总结起来,我们通过定义一个
SignalHandler类型的对象,告知系统在遇到系统信号的时候回调
handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过
Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,这个函数实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回我们之前设置的处理方式(
SolverAction::Enum类型)。剩余的具体处理再交给
Solver的其它函数,后面会具体分析。
Solver::Solve()
具体实现
Solve函数实现了具体的网络的优化过程,下面我们来具体分析一下这部分的代码,分析见注释:
1 void Solver<Dtype>::Solve(const char* resume_file) { 2 // 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码) 3 CHECK(Caffe::root_solver()); 4 // 然后输出learning policy(更新学习率的策略) 5 LOG(INFO) << "Solving " << net_->name(); 6 LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); 7 // requested_early_exit_`一开始被赋值为false,也就是现在没有要求在优化结束前退出 8 requested_early_exit_ = false; 9 // 判断`resume_file`这个指针是否NULL,如果不是则需要从resume_file存储的路径里读取之前训练的状态 10 if (resume_file) { 11 LOG(INFO) << "Restoring previous solver status from " << resume_file; 12 Restore(resume_file); 13 } 14 // 然后调用了'Step'函数,这个函数执行了实际的逐步的迭代过程 15 Step(param_.max_iter() - iter_); 16 // 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot 17 // 这个可以在solver.prototxt里设置 18 if (param_.snapshot_after_train() 19 && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { 20 Snapshot(); 21 } 22 // 如果在`Step`函数的迭代过程中遇到了系统信号,且我们的处理方式设置为`STOP`, 23 // 那么`requested_early_exit_`会被修改为true,迭代提前结束,输出相关信息 24 if (requested_early_exit_) { 25 LOG(INFO) << "Optimization stopped early."; 26 return; 27 } 28 // 判断是否需要输出最后的loss 29 if (param_.display() && iter_ % param_.display() == 0) { 30 Dtype loss; 31 net_->ForwardPrefilled(&loss); 32 LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; 33 } 34 // 判断是否需要最后Test 35 if (param_.test_interval() && iter_ % param_.test_interval() == 0) { 36 TestAll(); 37 } 38 LOG(INFO) << "Optimization Done."; 39 }
下面继续分析具体的迭代过程发生的
Step函数:
1 template <typename Dtype> 2 void Solver<Dtype>::Step(int iters) { 3 vector<Blob<Dtype>*> bottom_vec; 4 // 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter_等于snapshot时的迭代次数)和结束的迭代次数 5 const int start_iter = iter_; 6 const int stop_iter = iter_ + iters; 7 // 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1, 8 // losses存储之前的average_loss个loss,smoothed_loss为最后要输出的均值 9 int average_loss = this->param_.average_loss(); 10 vector<Dtype> losses; 11 Dtype smoothed_loss = 0; 12 // 迭代 13 while (iter_ < stop_iter) { 14 // 清空上一次所有参数的梯度 15 net_->ClearParamDiffs(); 16 // 判断是否需要测试 17 if (param_.test_interval() && iter_ % param_.test_interval() == 0 18 && (iter_ > 0 || param_.test_initialization()) 19 && Caffe::root_solver()) { 20 TestAll(); 21 // 判断是否需要提前结束迭代 22 if (requested_early_exit_) { 23 break; 24 } 25 } 26 for (int i = 0; i < callbacks_.size(); ++i) { 27 callbacks_[i]->on_start(); 28 } 29 // 判断当前迭代次数是否需要显示loss等信息 30 const bool display = param_.display() && iter_ % param_.display() == 0; 31 net_->set_debug_info(display && param_.debug_info()); 32 Dtype loss = 0; 33 // iter_size也是在solver.prototxt里设置,实际上的batch_size=iter_size*网络定义里的batch_size, 34 // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用`Net::ForwardBackward`函数得到的 35 // 这个设置我的理解是在GPU的显存不够的时候使用,比如我本来想把batch_size设置为128,但是会out_of_memory, 36 // 借助这个方法,可以设置batch_size=32,iter_size=4,那实际上每次迭代还是处理了128个数据 37 for (int i = 0; i < param_.iter_size(); ++i) { 38 loss += net_->ForwardBackward(bottom_vec); 39 } 40 loss /= param_.iter_size(); 41 // 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉 42 if (losses.size() < average_loss) { 43 losses.push_back(loss); 44 int size = losses.size(); 45 smoothed_loss = (smoothed_loss * (size - 1) + loss) / size; 46 } else { 47 int idx = (iter_ - start_iter) % average_loss; 48 smoothed_loss += (loss - losses[idx]) / average_loss; 49 losses[idx] = loss; 50 } 51 // 输出当前迭代的信息 52 if (display) { 53 LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ 54 << ", loss = " << smoothed_loss; 55 const vector<Blob<Dtype>*>& result = net_->output_blobs(); 56 int score_index = 0; 57 for (int j = 0; j < result.size(); ++j) { 58 const Dtype* result_vec = result[j]->cpu_data(); 59 const string& output_name = 60 net_->blob_names()[net_->output_blob_indices()[j]]; 61 const Dtype loss_weight = 62 net_->blob_loss_weights()[net_->output_blob_indices()[j]]; 63 for (int k = 0; k < result[j]->count(); ++k) { 64 ostringstream loss_msg_stream; 65 if (loss_weight) { 66 loss_msg_stream << " (* " << loss_weight 67 << " = " << loss_weight * result_vec[k] << " loss)"; 68 } 69 LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" 70 << score_index++ << ": " << output_name << " = " 71 << result_vec[k] << loss_msg_stream.str(); 72 } 73 } 74 } 75 for (int i = 0; i < callbacks_.size(); ++i) { 76 callbacks_[i]->on_gradients_ready(); 77 } 78 // 执行梯度的更新,这个函数在基类`Solver`中没有实现,会调用每个子类自己的实现,后面具体分析`SGDSolver`的实现 79 ApplyUpdate(); 80 // 迭代次数加1 81 ++iter_; 82 // 调用GetRequestedAction,实际是通过action_request_function_函数指针调用之前设置好(通过`SetRequestedAction`)的 83 // signal_handler的`CheckForSignals`函数,这个函数的作用是 84 // 会根据之前是否遇到系统信号以及信号的类型和我们设置(或者默认)的方式返回处理的方式 85 SolverAction::Enum request = GetRequestedAction(); 86 // 判断当前迭代是否需要snapshot,如果request等于`SNAPSHOT`则也需要 87 if ((param_.snapshot() 88 && iter_ % param_.snapshot() == 0 89 && Caffe::root_solver()) || 90 (request == SolverAction::SNAPSHOT)) { 91 Snapshot(); 92 } 93 // 如果request为`STOP`则修改`requested_early_exit_`为true,之后就会提前结束迭代 94 if (SolverAction::STOP == request) { 95 requested_early_exit_ = true; 96 break; 97 } 98 } 99 }
SGDSolver::ApplyUpdate
具体实现
每一组网络中的参数的更新都是在不同类型的Solver自己实现的ApplyUpdate函数中完成的,下面我们就以最常用的SGD为例子来分析这个函数具体的功能:
1 template <typename Dtype> 2 void SGDSolver<Dtype>::ApplyUpdate() { 3 CHECK(Caffe::root_solver()); 4 // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值 5 Dtype rate = GetLearningRate(); 6 // 判断是否需要输出当前的learning rate 7 if (this->param_.display() && this->iter_ % this->param_.display() == 0) { 8 LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; 9 } 10 // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小 11 ClipGradients(); 12 // 对所有可更新的网络参数进行操作 13 for (int param_id = 0; param_id < this->net_->learnable_params().size(); 14 ++param_id) { 15 // 将第param_id个参数的梯度除以iter_size,这一步的作用是保证实际的batch_size=iter_size*设置的batch_size 16 Normalize(param_id); 17 // 将正则化部分的梯度降入到每个参数的梯度中 18 Regularize(param_id); 19 // 计算SGD算法的梯度(momentum等) 20 ComputeUpdateValue(param_id, rate); 21 } 22 // 调用`Net::Update`更新所有的参数 23 this->net_->Update(); 24 }
下面我们继续具体分析一下
Normalize/
Regularize/
ComputeUpdateValue的实现,我们均以CPU的代码为例子,GPU部分的处理原理是一样的:
Normalize
1 template <typename Dtype> 2 void SGDSolver<Dtype>::Normalize(int param_id) { 3 // 如果iter_size的值为1,则不需要任何处理直接return 4 if (this->param_.iter_size() == 1) { return; } 5 // 通过net_返回所有可以学习的参数,是一个vector<shared_ptr<Blob<Dtype> > > 6 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 7 // 要乘以的系数等于1/iter_size 8 const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); 9 switch (Caffe::mode()) { 10 case Caffe::CPU: { 11 // caffe_scal在/CAFFE_ROOT/src/caffe/util/math_functions.cpp中 12 // 是blas的scale函数的一个封装,第一个参数是数据的个数,第二个参数是乘以的系数, 13 // 第三个参数是数据的指针 14 caffe_scal(net_params[param_id]->count(), accum_normalization, 15 net_params[param_id]->mutable_cpu_diff()); 16 break; 17 } 18 case Caffe::GPU: { 19 // GPU代码略 20 } 21 }
Regularize
1 template <typename Dtype> 2 void SGDSolver<Dtype>::Regularize(int param_id) { 3 // 获取所有可以学习的参数的vector 4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 5 // 获取所有的参数对应的weight_decay的vector 6 const vector<float>& net_params_weight_decay = 7 this->net_->params_weight_decay(); 8 // 模型整体的weight_decay数值 9 Dtype weight_decay = this->param_.weight_decay(); 10 // 获取正则化的类型:L1 或 L2 11 string regularization_type = this->param_.regularization_type(); 12 // 实际的weight_decay等于整体模型的数值乘以具体每个参数的数值 13 Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; 14 switch (Caffe::mode()) { 15 case Caffe::CPU: { 16 // 如果weight_decay不为0,则计算 17 if (local_decay) { 18 if (regularization_type == "L2") { 19 // L2的梯度为diff_ = weight_decay*data_ + diff_ 20 // caffe_axpy的功能是 y = a*x + y 21 // 第一个参数是数据的个数,第二个是上式的a,第三个是x的指针,第四个是y的指针 22 caffe_axpy(net_params[param_id]->count(), 23 local_decay, 24 net_params[param_id]->cpu_data(), 25 net_params[param_id]->mutable_cpu_diff()); 26 } else if (regularization_type == "L1") { 27 // L1的梯度为diff_ = diff_ + sign(data_) 28 // temp_ = sign(data_) 29 caffe_cpu_sign(net_params[param_id]->count(), 30 net_params[param_id]->cpu_data(), 31 temp_[param_id]->mutable_cpu_data()); 32 // 将temp_加到diff_中 diff_ = weight_decay*temp_ + diff_ 33 caffe_axpy(net_params[param_id]->count(), 34 local_decay, 35 temp_[param_id]->cpu_data(), 36 net_params[param_id]->mutable_cpu_diff()); 37 } else { 38 LOG(FATAL) << "Unknown regularization type: " << regularization_type; 39 } 40 } 41 break; 42 } 43 // GPU代码略 44 }
ComputeUpdatedValue
1 template <typename Dtype> 2 void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { 3 // 获取所有可以更新的参数的vector 4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 5 // 获取所有参数对应的learning_rate的vector 6 const vector<float>& net_params_lr = this->net_->params_lr(); 7 // 获取momentum数值 8 Dtype momentum = this->param_.momentum(); 9 // 实际的learning_rate为全局的learning_rate乘以每个参数对应的learning_rate 10 Dtype local_rate = rate * net_params_lr[param_id]; 11 switch (Caffe::mode()) { 12 case Caffe::CPU: { 13 // 关于SGD的公式参考caffe官网tutorial的Solver部分 14 // history_存储了上一次的梯度,下面这个函数: 15 // history_ = learning_rate*diff_ + momentum*history 16 caffe_cpu_axpby(net_params[param_id]->count(), local_rate, 17 net_params[param_id]->cpu_diff(), momentum, 18 history_[param_id]->mutable_cpu_data()); 19 // 把当前的梯度拷贝给参数Blob的diff_ 20 caffe_copy(net_params[param_id]->count(), 21 history_[param_id]->cpu_data(), 22 net_params[param_id]->mutable_cpu_diff()); 23 break; 24 } 25 case Caffe::GPU: { 26 // GPU代码略 27 } 28 }
至此
Solver主要的代码都已经分析完了,总结起来主要有:(1)solver_factory的register和create不同类型Solver的机制,(2)通过signal_handler来获取系统信号,并根据用户或默认的设置进行相应的处理,(3)
Solver::Solve函数的具体实现的分析,(4)
SGDSolver::ApplyUpdate函数的具体实现。前面三个部分都属于基类的,最后一个是SGDSolver这个子类的,如果用户想要实现自己的Solver类,也应该类似地去继承基类,并实现自己的
ApplyUpdate函数,在代码的末尾通过register宏完成注册,便可以被成功的调用。
相关文章推荐
- 第六周(1) 后台代码编写与客户端具体功能实现以及界面优化
- caffe代码阅读3:data_reader、internalthread以及blocking_queue的实现细节-2016.3.15
- iOS实现地图定位(具体实现代码以及注释详解)
- Hilbert曲线介绍以及代码实现
- Android平台Camera实时滤镜实现方法探讨(十)--代码地址以及简单介绍(20160118更新)
- 线程的同步异步,以及具体代码实现,使用场景
- caffe代码阅读6:SyncedMemory的j介绍与实现
- 红黑树的介绍以及代码实现(C++)
- Java实现MD5加密以及解密类,附带测试类,具体见代码。
- ID3、C4.5算法介绍以及java代码实现
- 循环队列的判断满、空的三种方法以及具体代码实现(数组实现)
- caffe代码阅读4:LayerRegistry的介绍与实现
- caffe中的solver.protxt的test_iter以及test_interval的区别
- caffe代码阅读2:DataTransformer以及io的实现细节
- caffe代码阅读1:Layer的介绍与实现细节
- caffe代码阅读4:DataTransformer以及io的实现细节-2016.3.16
- 原始LBP纹理特征提取方法介绍以及代码实现
- [置顶] 树:哈夫曼树和哈夫曼编码的详细介绍以及代码实现
- 本文给出了一种方便实用的解决大文件的读取、存储等处理的方法,并结合相关程序代码对具体的实现过程进行了介绍
- Android jni aes加解密,实现文件的加解密,具体实现可以自行修改,上面的代码为简单介绍,下面的是JNI端实现文件加解密,可以修改为字符串加解密