您的位置:首页 > 运维架构

YOLO源码详解(四)- 反向传播(back propagation)

2016-12-26 11:51 337 查看
本系列作者:木凌 

时间:2016年12月。 

文章连接:http://blog.csdn.net/u014540717

反向传播是CNN中非常重要的一个环节,对于理论部分,这里不做介绍,如果对反向传播理论部分不熟悉,可以查看以下网站。 

非常详细:零基础入门深度学习(3) - 神经网络和反向传播算法 

非常详细:零基础入门深度学习(4) - 卷积神经网络 

非常生动:如何直观的解释back propagation算法? 

通过以上理论部分的学习,如果你还是感觉一脸蒙逼,那就看YOLO的代码吧,看完源代码你就会豁然开朗。让我们来一睹“back propagation”芳容


一、主函数backward_network(network net, network_state state)

//network.c
void backward_network(network net, network_state state)
{
int i;
float *original_input = state.input;
float *original_delta = state.delta;
state.workspace = net.workspace;
for(i = net.n-1; i >= 0; --i){
state.index = i;
if(i == 0){
state.input = original_input;
state.delta = original_delta;
}else{
layer prev = net.layers[i-1];
state.input = prev.output;
//这里注意,因为delta是指针变量,对state.delta做修改,就相当与对prev层的delta做了修改
state.delta = prev.delta;
}
layer l = net.layers[i];
l.backward(l, state);
}
}
//这函数没什么好说的,一层一层看吧,顺序如下:
//[detection]
//[connected]
//[dropout]
//[local]
//[convolutional]
//[maxpool]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

1、反向传播-detection层
//detection_layer.c
void backward_detection_layer(const detection_layer l, network_state state)
{
//给state.delta赋值,l.delta存放的是预测值与真实值的差
axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
}
//blas.c
//axpy函数:y += a * x
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
int i;
for(i = 0; i < N; ++i) Y[i*INCY] += ALPHA*X[i*INCX];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
1
2
3
4
5
6
7
8
9
10
11
12
13

2、反向传播-connected层
//connected_layer.c
void backward_connected_layer(connected_layer l, network_state state)
{
int i;
//计算激活层的梯度值
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
//把batch size每个样本对应的值加起来,放入bias_updates指向的内存
for(i = 0; i < l.batch; ++i){
axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1);
}
//全链接层没用到batch_normalize,这里不做介绍
if(l.batch_normalize){
backward_scale_cpu(l.x_norm, l.delta, l.batch, l.outputs, 1, l.scale_updates);

scale_bias(l.delta, l.scales, l.batch, l.outputs, 1);

mean_delta_cpu(l.delta, l.variance, l.batch, l.outputs, 1, l.mean_delta);
variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.outputs, 1, l.variance_delta);
normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta, l.batch, l.outputs, 1, l.delta);
}

int m = l.outputs;
int k = l.batch;
int n = l.inputs;
float *a = l.delta;
float *b = state.input;
float *c = l.weight_updates;
//更新这一层的权重值
gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);

m = l.batch;
k = l.outputs;
n = l.inputs;

a = l.delta;
b = l.weights;
c = state.delta;
//更新前一(prev)层的误差项
if(c) gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

3、反向传播-dropout层
//dropout_layer.c
void backward_dropout_layer(dropout_layer l, network_state state)
{
int i;
if(!state.delta) return;
for(i = 0; i < l.batch * l.inputs; ++i){
//l.rand[i]就是0~1之间的随机数,这在前向传播的时候有讲
float r = l.rand[i];
//同样将前一层的delta赋值为0
if(r < l.probability) state.delta[i] = 0;
else state.delta[i] *= l.scale;
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
1
2
3
4
5
6
7
8
9
10
11
12
13

4、反向传播-local层
//local_layer.c
void backward_local_layer(local_layer l, network_state state)
{
int i, j;
int locations = l.out_w*l.out_h;
//计算激活层梯度
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
//跟新bias_updates
for(i = 0; i < l.batch; ++i){
axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1);
}

for(i = 0; i < l.batch; ++i){
float *input = state.input + i*l.w*l.h*l.c;
im2col_cpu(input, l.c, l.h, l.w,
l.size, l.stride, l.pad, l.col_image);

for(j = 0; j < locations; ++j){
float *a = l.delta + i*l.outputs + j;
float *b = l.col_image + j;
float *c = l.weight_updates + j*l.size*l.size*l.c*l.n;
int m = l.n;
int n = l.size*l.size*l.c;
int k = 1;
//更新权重
gemm(0,1,m,n,k,1,a,locations,b,locations,1,c,n);
}

if(state.delta){
for(j = 0; j < locations; ++j){
float *a = l.weights + j*l.size*l.size*l.c*l.n;
float *b = l.delta + i*l.outputs + j;
float *c = l.col_image + j;

int m = l.size*l.size*l.c;
int n = 1;
int k = l.n;
//更新下一层误差项
gemm(1,0,m,n,k,1,a,m,b,locations,0,c,locations);
}

col2im_cpu(l.col_image, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
}
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

5、反向传播-convolutional层
//convolutional_layer.c
//跟local层一样~
void backward_convolutional_layer(convolutional_layer l, network_state state)
{
int i;
int m = l.n;
int n = l.size*l.size*l.c;
int k = convolutional_out_height(l)*
convolutional_out_width(l);

gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);

for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
float *b = state.workspace;
float *c = l.weight_updates;

float *im = state.input+i*l.c*l.h*l.w;

im2col_cpu(im, l.c, l.h, l.w,
l.size, l.stride, l.pad, b);
gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);

if(state.delta){
a = l.weights;
b = l.delta + i*m*k;
c = state.workspace;

gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);

col2im_cpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
}
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

6、反向传播-maxpool层
//maxpool_layer.c
void backward_maxpool_layer(const maxpool_layer l, network_state state)
{
int i;
int h = l.out_h;
int w = l.out_w;
int c = l.c;
for(i = 0; i < h*w*c*l.batch; ++i){
//l.indexes存储的是前一层最大值的坐标
int index = l.indexes[i];
state.delta[index] += l.delta[i];
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
1
2
3
4
5
6
7
8
9
10
11
12
13




0
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: