Caffe源码阅读(2) 卷积层
2016-04-22 23:42
337 查看
http://ufldl.stanford.edu/tutorial/supervised/ConvolutionalNeuralNetwork/
http://cogprints.org/5869/1/cnn_tutorial.pdf
卷积层的参数的梯度可以这样来求:
∇W(l)kJ(W,b;x,y)∇b(l)kJ(W,b;x,y)=∑i=1m(a(l)i)∗rot90(δ(l+1)k,2),=∑a,b(δ(l+1)k)a,b.∇Wk(l)J(W,b;x,y)=∑i=1m(ai(l))∗rot90(δk(l+1),2),∇bk(l)J(W,b;x,y)=∑a,b(δk(l+1))a,b.
看上去比全连接层复杂多了,但其实,他们本质上基本是一样的,依然可以套回全连接层的参数求导公式:
∇W(l)J(W,b;x,y)∇b(l)J(W,b;x,y)=δ(l+1)(a(l))T,=δ(l+1).∇W(l)J(W,b;x,y)=δ(l+1)(a(l))T,∇b(l)J(W,b;x,y)=δ(l+1).
只需要额外增加一步
im2col。这一步的意思是将首先将整张图片按照卷积的窗口大小切好(按照stride来切,可以有重叠),然后各自拉成一列。
为啥要怎样做,因为对于这个小窗口内拉成一列的神经元来说来说,它们跟下一层神经元就是全连接了,所以这个小窗口里面的梯度计算就可以按照全连接来计算就可以了。
如果对照着Caffe的卷积层源码来看,就很清晰了。
forward的代码如下,假设没有分group,这段代码的意思是对于一个大小为num_的batch里面的任意一张图片,首先通过
im2col展开成多个列向量,之后直接就用wx+b的方式就能够算到输出了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) { for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = (*top)[i]->mutable_cpu_data(); Dtype* col_data = col_buffer_.mutable_cpu_data(); const Dtype* weight = this->blobs_[0]->cpu_data(); int weight_offset = M_ * K_; // number of filter parameters in a group int col_offset = K_ * N_; // number of values in an input region / column int top_offset = M_ * N_; // number of values in an output region / column for (int n = 0; n < num_; ++n) { // im2col transformation: unroll input regions for filtering // into column matrix for multplication. im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_, width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_data); // Take inner products for groups. for (int g = 0; g < group_; ++g) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1., weight + weight_offset * g, col_data + col_offset * g, (Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g); } // Add bias. if (bias_term_) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_, N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(), bias_multiplier_.cpu_data(), (Dtype)1., top_data + (*top)[i]->offset(n)); } } } } |
言归正传,如果能够理解到
im2col的作用,那么backward的代码也很容易理解了。
对于bias,直接就是delta(可能还要乘以
bias_multiplier_,这个是Caffe自己的功能,默认不开启,即
bias_multiplier_=1)
1 2 3 4 5 6 7 8 9 10 | // Bias gradient, if necessary. if (bias_term_ && this->param_propagate_down_[1]) { top_diff = top[i]->cpu_diff(); for (int n = 0; n < num_; ++n) { caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, N_, 1., top_diff + top[0]->offset(n), bias_multiplier_.cpu_data(), 1., bias_diff); } } |
im2col展开,之后用矩阵乘法表示累加:
1 2 3 4 5 6 7 8 9 1011 | // Since we saved memory in the forward pass by not storing all col // data, we will need to recompute them. im2col_cpu(bottom_data + (*bottom)[i]->offset(n), channels_, height_, width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_data); // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { for (int g = 0; g < group_; ++g) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g, col_data + col_offset * g, (Dtype)1., weight_diff + weight_offset * g); } } |
col2im,其实也是一个累加的过程,让每个空间位置的delta累加起来。
1 2 3 4 5 6 7 8 9 1011 | // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { if (weight == NULL) { weight = this->blobs_[0]->cpu_data(); } for (int g = 0; g < group_; ++g) { caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1., weight + weight_offset * g, top_diff + top[i]->offset(n) + top_offset * g, (Dtype)0., col_diff + col_offset * g); } // col2im back to the data col2im_cpu(col_diff, channels_, height_, width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, bottom_diff + (*bottom)[i]->offset(n)); } |
相关文章推荐
- 基础DOM和CSS操作(一)
- html5的关于布局的研究(1)
- Caffe源码阅读(3)Softmax层和SoftmaxLoss层
- 剑指offer(二十四)之数组中出现次数超过一半的数字
- <css 十八>图片的透明
- BlogLife
- 前端笔记 CSS 5
- css设置网页占满屏幕
- Angularjs基础(七)
- 剑指offer:反转链表
- link和@import的区别
- js冒泡排序
- Web开发(二)--JSP
- 利用jquery写无缝循环滑动的轮播图
- django使用html模板减少代码
- css在线参考手册
- 第 2 章 排版样式
- Json(org.json)简单封装与解析
- 前端开发框架对比
- ajaxfileupload返回的json数据带<pre></pre>标签