您的位置:首页 > 理论基础 > 计算机网络

(Caffe,LeNet)网络训练流程(二)

2017-07-28 09:13 465 查看
版权声明:未经允许请勿用于商业用途,转载请注明出处:http://blog.csdn.net/mounty_fsc/

目录(?)[+]
程序入口
Solver的创建
SolverSolve函数
SolverStep函数
1 SolverTestAll函数
2 NetForwardBackward函数
3 SolverApplyUpdate函数

训练完毕

本文地址:http://blog.csdn.net/mounty_fsc/article/details/51090114

在训练lenet的
train_lenet.sh
中内容为:

./build/tools/caffe train –solver=examples/mnist/lenet_solver.prototxt

由此可知,训练网咯模型是由
tools/caffe.cpp
生成的工具
caffe
在模式
train
下完成的。

初始化过程总的来说,从
main()
train()
中创建
Solver
,在
Solver
中创建
Net
,在
Net
中创建
Layer
.

1 程序入口

找到
caffe.cpp
main
函数中,通过
GetBrewFunction(caffe::string(argv[1]))()
调用执行
train()
函数。
train中
,通过参数
-examples/mnist/lenet_solver.prototxt
solver
参数读入
solver_param
中。

随后注册并定义
solver
的指针(见第2节)

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))
1
2


1
2

调用
solver
Solver()
方法。多个GPU涉及到GPU间带异步处理问题(见第3节)

if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
sync.run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();
}
1
2
3
4
5
6
7


1
2
3
4
5
6
7

2 Solver的创建

在1中,
Solver
的指针
solver
是通过
SolverRegistry::CreateSolver
创建的,
CreateSolver
函数中值得注意带是
return registry[type](param)


// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
1
2
3
4
5
6
7
8


1
2
3
4
5
6
7
8
其中:

registry
是一个
map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry


其中
Creator
是一个函数指针类型:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&)


registry[type]
为一个函数指针变量,在
Lenet5
中,此处具体的值为
caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)


其中
Creator_SGDSolver
在以下宏中定义,
REGISTER_SOLVER_CLASS(SGD)


该宏完全展开得到的内容为:

template <typename Dtype>                                                    \
Solver<Dtype>* Creator_SGDSolver(                                       \
const SolverParameter& param)                                            \
{                                                                            \
return new SGDSolver<Dtype>(param);                                     \
}                                                                            \
static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);    \
static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)
1
2
3
4
5
6
7
8


1
2
3
4
5
6
7
8
从上可以看出,
registry[type](param)
中实际上调用了
SGDSolver
带构造方法,事实上,网络是在
SGDSolver
的构造方法中初始化的。

SGDSolver
的定义如下:

template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
......
1
2
3
4
5
6
7
8


1
2
3
4
5
6
7
8
SGDSolver
继承与
Solver<Dtype>
,因而
new SGDSolver<Dtype>(param)
将执行
Sol
b22c
ver<Dtype>
的构造函数,然后调用自身构造函数。整个网络带初始化即在这里面完成(详见本系列博文(三))。

3 Solver::Solve()函数

在这个函数里面,程序执行完网络的完整训练过程。

核心代码如下:

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {

Step(param_.max_iter() - iter_);
//..
Snapshot();
//..

// some additional display
// ...
}
1
2
3
4
5
6
7
8
9
10
11


1
2
3
4
5
6
7
8
9
10
11
说明:

值得关注的代码是
Step()
,在该函数中,值得了
param_.max_iter()
轮迭代(10000)
在Snapshot()中序列化model到文件

4 Solver::Step()函数

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {

//10000轮迭代
while (iter_ < stop_iter) {

// 每隔500轮进行一次测试
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
// 测试网络,实际是执行前向传播计算loss
TestAll();
}

// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
// 执行反向传播,前向计算损失loss,并计算loss关于权值的偏导
loss += net_->ForwardBackward(bottom_vec);
}

// 平滑loss,计算结果用于输出调试等
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);

// 通过反向传播计算的偏导更新权值
ApplyUpdate();

}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

4.1 Solver::TestAll()函数

TestAll()
中,调用
Test(test_net_id)
对每个测试网络test_net(不是训练网络train_net)进行测试。在Lenet中,只有一个测试网络,所以只调用一次
Test(0)
进行测试。

Test()函数里面做了两件事:

前向计算网络,得到网络损失,见 (Caffe,LeNet)前向计算(五)

通过测试网络的第11层accuracy层,与第12层loss层结果统计accuracy与loss信息。

4.2 Net::ForwardBackward()函数

Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
Dtype loss;
Forward(bottom, &loss);
Backward();
return loss;
}
1
2
3
4
5
6


1
2
3
4
5
6
说明:

前向计算。计算网络损失loss,参考 (Caffe,LeNet)前向计算(五)
反向传播。计算loss关于网络权值的偏导,参考 (Caffe,LeNet)反向传播(六)

4.3 Solver::ApplyUpdate()函数

根据反向传播阶段计算的loss关于网络权值的偏导,使用配置的学习策略,更新网络权值从而完成本轮学习。详见 (Caffe,LeNet)权值更新(七)

5 训练完毕

至此,网络训练优化完成。在第3部分solve()函数中,最后对训练网络与测试网络再执行一轮额外的前行计算求得loss,以进行测试。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  Caffe