
Caffe Source Code Analysis: The Layer Classes (1)

2015-08-10 14:47

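Preface

To be honest, Caffe has a lot of layer code, and the various abstractions look rather convoluted at first. The official Layer tutorial is clearly written; based on that document I drew a simple diagram, which made things a little easier to follow.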



layer.hpp

The layer-related header files are:

common_layers.hpp
data_layers.hpp
layer.hpp
loss_layers.hpp
neuron_layers.hpp
vision_layers.hpp

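Of these, layer.hpp defines the abstract base class, and the others all derive from it; that is, the remaining five headers correspond to the five parts of the diagram above. layer.hpp itself includes the following headers: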

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/device_alternate.hpp"

In device_alternate.hpp, an #ifdef CPU_ONLY block defines a few macros that stub out the GPU calls:

#define STUB_GPU(classname)
#define STUB_GPU_FORWARD(classname, funcname)
#define STUB_GPU_BACKWARD(classname, funcname)
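
To make the stubbing idea concrete, here is a minimal, self-contained sketch. It is not Caffe's actual macro definition: DemoLayer, DEMO_CPU_ONLY, DEMO_STUB_GPU and the exception-based error are hypothetical names and choices made only for illustration.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical layer: the CPU path always exists, while the GPU path is
// only declared here and gets its definition from the stub macro below.
class DemoLayer {
 public:
  int Forward_cpu(int x) const { return x * 2; }
  int Forward_gpu(int x) const;
};

// Assumption: this sketch always behaves like a CPU-only build.
#define DEMO_CPU_ONLY

#ifdef DEMO_CPU_ONLY
// In a CPU-only build the macro expands to a definition that merely
// reports an error, so the class still links without any GPU code.
#define DEMO_STUB_GPU(classname)                                        \
  int classname::Forward_gpu(int) const {                               \
    throw std::runtime_error(#classname ": built without GPU support"); \
  }
#else
// In a GPU build the macro expands to nothing.
#define DEMO_STUB_GPU(classname)
#endif

DEMO_STUB_GPU(DemoLayer)

// Helper for the example: true if calling the GPU path hits the stub.
bool GpuIsStubbed(const DemoLayer& layer) {
  try {
    layer.Forward_gpu(1);
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```

With the macro in place, a layer's source file needs just one extra line to remain buildable without a GPU toolchain.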

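A Layer has three main member variables:

LayerParameter layer_param_;      // the layer parameters as stored in the protobuf file
vector<shared_ptr<Blob<Dtype> > > blobs_;        // the layer's own parameters, as used by the program at run time
vector<bool> param_propagate_down_;        // whether to compute the diff of each parameter blob, i.e. whether to propagate the error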

The Layer constructor,
explicit Layer(const LayerParameter& param) : layer_param_(param)
tries to read the parameters from the protobuf definition. Its three main interfaces are:

virtual void SetUp(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
inline Dtype Forward(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top);
inline void Backward(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);

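SetUp must be implemented according to the actual parameter settings and initializes the various kinds of parameters. Forward and Backward correspond to the forward computation and the backward pass. In both, bottom denotes the layer's input blobs and top its output blobs: Forward reads bottom and writes top, while Backward reads the gradients stored in top and writes the gradients of bottom. Backward also takes a propagate_down argument, a vector of flags with one entry per bottom blob, indicating whether the gradient should be propagated to that blob.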

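Internally, Forward and Backward check Caffe::mode() and run on either the CPU or the GPU. Each mode has its own interface: Forward_cpu, Forward_gpu, Backward_cpu and Backward_gpu. These are all virtual, so the actual computation depends on the concrete layer type. (Note: some layers have no GPU implementation, so the base class falls back to the CPU computation.) In addition, a ToProto interface writes the layer's parameters out to a protocol buffer.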
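
The dispatch and the CPU fallback described above can be sketched as follows. This is a simplified illustration, not Caffe's code: blobs are replaced by std::vector<float>, the mode is passed in as an argument instead of being read from Caffe::mode(), and the names Mode, ModeLayer and DoubleLayer are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Stand-in for the value returned by Caffe::mode() (hypothetical enum).
enum class Mode { CPU, GPU };

class ModeLayer {
 public:
  virtual ~ModeLayer() {}

  // Public entry point: picks the implementation that matches the mode.
  void Forward(Mode mode, const std::vector<float>& bottom,
               std::vector<float>* top) {
    if (mode == Mode::GPU) {
      Forward_gpu(bottom, top);
    } else {
      Forward_cpu(bottom, top);
    }
  }

 protected:
  // Every concrete layer must provide a CPU implementation.
  virtual void Forward_cpu(const std::vector<float>& bottom,
                           std::vector<float>* top) = 0;
  // Layers without GPU code inherit this fallback to the CPU path.
  virtual void Forward_gpu(const std::vector<float>& bottom,
                           std::vector<float>* top) {
    Forward_cpu(bottom, top);
  }
};

// A layer that only implements the CPU path: it doubles every element.
class DoubleLayer : public ModeLayer {
 protected:
  void Forward_cpu(const std::vector<float>& bottom,
                   std::vector<float>* top) override {
    assert(top != nullptr);
    top->assign(bottom.begin(), bottom.end());
    for (float& v : *top) v *= 2.0f;
  }
};
```

Because the fallback lives in the base class, DoubleLayer works in both modes although it never mentions the GPU.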

data_layers.hpp

The data_layers.hpp header includes the following headers:

#include "boost/scoped_ptr.hpp"
#include "hdf5.h"
#include "leveldb/db.h"
#include "lmdb.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

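Seeing hdf5, leveldb and lmdb here, it is clear that this header deals with concrete data. As the input layer for raw data, data_layer sits at the very bottom of the network. It can read data from a LevelDB or LMDB database, directly from memory, from HDF5, or even from raw images.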

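A brief introduction to these storage formats:

LevelDB is a high-performance key/value storage library from Google. It is simple to call, and the data is compressed with Snappy, which is reported to be efficient and to reduce disk I/O. See Wikipedia for concrete examples.

LMDB (Lightning Memory-Mapped Database) is a key/value store similar to LevelDB, but apparently with better performance; its home page describes it as "ultra-fast, ultra-compact". This deserves further study.

HDF (Hierarchical Data Format) is a file format, with accompanying libraries, designed for storing and processing large volumes of scientific data. The most popular version at present is HDF5, whose files contain two basic kinds of data object:

Group: similar to a folder; it can contain multiple datasets or subgroups.
Dataset: the data itself, which can be a multidimensional array or a more complex data type.

The above comes from Wikipedia. For usage, see the article "HDF5 小试——高大上的多对象文件格式" (a first try of HDF5, a multi-object file format); I will look into how to use it in more detail later.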

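caffe/filler.hpp is responsible for filling in the initial parameter values when the network is initialized, according to the layer definition. The code below is straightforward: it creates the filler that matches the type specified in FillerParameter.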

// A function to get a specific filler from the specification given in
// FillerParameter. Ideally this would be replaced by a factory pattern,
// but we will leave it this way for now.
template <typename Dtype>
Filler<Dtype>* GetFiller(const FillerParameter& param) {
  const std::string& type = param.type();
  if (type == "constant") {
    return new ConstantFiller<Dtype>(param);
  } else if (type == "gaussian") {
    return new GaussianFiller<Dtype>(param);
  } else if (type == "positive_unitball") {
    return new PositiveUnitballFiller<Dtype>(param);
  } else if (type == "uniform") {
    return new UniformFiller<Dtype>(param);
  } else if (type == "xavier") {
    return new XavierFiller<Dtype>(param);
  } else {
    CHECK(false) << "Unknown filler name: " << param.type();
  }
  return (Filler<Dtype>*)(NULL);
}
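
To show how such a type-string dispatch is used, here is a reduced, self-contained sketch. Everything in it is a stand-in: the Demo* classes fill a plain std::vector<float> rather than a Blob, and the constant 0.5 and the range [-1, 1] are hard-coded assumptions, whereas in Caffe those values come from FillerParameter.

```cpp
#include <cassert>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for the Filler base class.
struct DemoFiller {
  virtual ~DemoFiller() {}
  virtual void Fill(std::vector<float>* data) const = 0;
};

// Sets every element to the same value.
struct DemoConstantFiller : DemoFiller {
  explicit DemoConstantFiller(float value) : value_(value) {}
  void Fill(std::vector<float>* data) const override {
    assert(data != nullptr);
    data->assign(data->size(), value_);
  }
  float value_;
};

// Draws every element uniformly from the range given at construction.
struct DemoUniformFiller : DemoFiller {
  DemoUniformFiller(float min, float max) : min_(min), max_(max) {}
  void Fill(std::vector<float>* data) const override {
    assert(data != nullptr);
    std::mt19937 rng(0);  // fixed seed keeps the sketch reproducible
    std::uniform_real_distribution<float> dist(min_, max_);
    for (float& v : *data) v = dist(rng);
  }
  float min_, max_;
};

// Same idea as GetFiller: map a type string to a concrete filler object.
std::unique_ptr<DemoFiller> GetDemoFiller(const std::string& type) {
  if (type == "constant") {
    return std::unique_ptr<DemoFiller>(new DemoConstantFiller(0.5f));
  } else if (type == "uniform") {
    return std::unique_ptr<DemoFiller>(new DemoUniformFiller(-1.0f, 1.0f));
  }
  throw std::invalid_argument("Unknown filler name: " + type);
}
```

A caller would then write something like GetDemoFiller("constant")->Fill(&weights), without needing to know which concrete class does the work.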

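internal_thread.hpp wraps the pthread functions; a subclass that inherits from it gets a separate thread of its own. Its main use is to fetch the next batch of data in the background while the current batch is being computed.

As for data_layer, I have marked the main points to note on the diagram.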

neuron_layers.hpp

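Once the data has been fed in, the computation starts, for example the common sigmoid, tanh and so on. These operations are abstracted into the class NeuronLayer in neuron_layers.hpp. This layer is only responsible for the actual computation, so it explicitly defines ExactNumBottomBlobs() and ExactNumTopBlobs() to return the constant 1: one input blob and one output blob.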
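
As an illustration of such a one-in, one-out computation, the following sketch implements the sigmoid forward and backward passes on plain vectors. The function names and the use of std::vector<double> instead of Blob are assumptions made for this example; it is not the code from Caffe's sigmoid layer.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward pass: y[i] = 1 / (1 + exp(-x[i])), one output per input.
std::vector<double> SigmoidForward(const std::vector<double>& bottom) {
  std::vector<double> top(bottom.size());
  for (std::size_t i = 0; i < bottom.size(); ++i) {
    top[i] = 1.0 / (1.0 + std::exp(-bottom[i]));
  }
  return top;
}

// Backward pass: dL/dx[i] = dL/dy[i] * y[i] * (1 - y[i]).
// top_data holds the forward outputs, top_diff the incoming gradients.
std::vector<double> SigmoidBackward(const std::vector<double>& top_data,
                                    const std::vector<double>& top_diff) {
  assert(top_data.size() == top_diff.size());
  std::vector<double> bottom_diff(top_data.size());
  for (std::size_t i = 0; i < top_data.size(); ++i) {
    bottom_diff[i] = top_diff[i] * top_data[i] * (1.0 - top_data[i]);
  }
  return bottom_diff;
}
```

Input and output always have the same number of elements, which is why a neuron layer can fix both blob counts at 1.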

common_layers.hpp

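NeuronLayer only handles simple one-to-one computations; the remaining, more complex operations are all placed in common_layers.hpp, such as ArgMaxLayer, ConcatLayer, FlattenLayer, SoftmaxLayer, SplitLayer and SliceLayer, which combine, split or otherwise modify blobs.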

loss_layers.hpp

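The data layer and common layer discussed above are all intermediate computation layers; although they take part in backpropagation, the gradient originates from loss_layer, at the very end of the network. Because this layer computes the error, it takes 2 blobs as input and produces 1 blob as output.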
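
The two-inputs, one-output shape of a loss layer can be illustrated with a Euclidean-style loss. This is a sketch under stated assumptions: the function name is hypothetical, plain vectors replace blobs, and the sum of squared differences is normalized by twice the element count, which may differ from the normalization a given Caffe loss layer uses.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Takes two inputs (prediction and target) and returns one scalar:
// loss = sum_i (prediction[i] - target[i])^2 / (2 * N),
// where N is the number of elements (an assumption of this sketch).
double EuclideanLossSketch(const std::vector<double>& prediction,
                           const std::vector<double>& target) {
  assert(prediction.size() == target.size());
  assert(!prediction.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < prediction.size(); ++i) {
    const double diff = prediction[i] - target[i];
    sum += diff * diff;
  }
  return sum / (2.0 * static_cast<double>(prediction.size()));
}
```

The scalar returned here plays the role of the single top blob, and its gradient with respect to the prediction is where backpropagation starts.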

vision_layers.hpp

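vision_layer mainly covers image convolution operations; convolution, pooling and LRN all live here. According to the official documentation these layers can output images, though that depends on the concrete implementation. There is also an im2col implementation which, according to the Caffe author's explanation, exists mainly to speed up convolution; exactly how it works deserves a closer look.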
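
The general im2col idea can be sketched in a few lines: every kernel-sized patch of the image is unrolled into one column, after which the convolution reduces to a single matrix product. The version below is a deliberately restricted illustration (one channel, stride 1, no padding, hypothetical function names) and is not Caffe's implementation.

```cpp
#include <cassert>
#include <vector>

// Unrolls every kernel x kernel patch of a single-channel, row-major
// height x width image into one column. The result is row-major with
// kernel*kernel rows and (height-kernel+1)*(width-kernel+1) columns.
std::vector<float> Im2ColSketch(const std::vector<float>& image,
                                int height, int width, int kernel) {
  assert(static_cast<int>(image.size()) == height * width);
  assert(kernel > 0 && kernel <= height && kernel <= width);
  const int out_h = height - kernel + 1;
  const int out_w = width - kernel + 1;
  std::vector<float> col(kernel * kernel * out_h * out_w);
  for (int ky = 0; ky < kernel; ++ky) {
    for (int kx = 0; kx < kernel; ++kx) {
      const int row = ky * kernel + kx;  // one row per kernel position
      for (int y = 0; y < out_h; ++y) {
        for (int x = 0; x < out_w; ++x) {
          col[(row * out_h + y) * out_w + x] =
              image[(y + ky) * width + (x + kx)];
        }
      }
    }
  }
  return col;
}

// With the patches unrolled, the convolution (cross-correlation) is a
// single product of the flattened kernel with the column matrix.
std::vector<float> ConvViaIm2Col(const std::vector<float>& image,
                                 int height, int width,
                                 const std::vector<float>& weights,
                                 int kernel) {
  assert(static_cast<int>(weights.size()) == kernel * kernel);
  const int out_size = (height - kernel + 1) * (width - kernel + 1);
  const std::vector<float> col = Im2ColSketch(image, height, width, kernel);
  std::vector<float> out(out_size, 0.0f);
  for (int r = 0; r < kernel * kernel; ++r) {
    for (int c = 0; c < out_size; ++c) {
      out[c] += weights[r] * col[r * out_size + c];
    }
  }
  return out;
}
```

The benefit in practice comes from handing that product to a highly optimized matrix-multiplication routine; the plain loops here only show the data layout.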

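Afterword

By combining the official documentation with drawing diagrams and reading the code, I finally have a basic understanding of the layer classes as a whole: data is responsible for input, vision for convolution-related computation, neuron and common for the data computation in the middle, and loss is the final part, responsible for computing the error that is backpropagated. The concrete implementations are all in src/caffe/layers, which I will study bit by bit.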

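Reading these abstract base-class headers is fairly tiring, but with some searching one can also pick up a few techniques, such as using macro definitions to shorten C and C++ code, and using templates so that classes with many repeated interfaces and parameters can be abstracted behind a macro definition, which simplifies the code.

What Layer does:

It is the base class of all network layers and defines common interfaces such as forward, backward, reshape and setup.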
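
How these common interfaces fit together can be sketched with a miniature base class. It follows the order used by SetUp in layer.hpp (CheckBlobCounts, then LayerSetUp, then Reshape) but leaves out SetLossWeights, reduces blobs to plain vectors, and throws an exception where Caffe uses CHECK macros; the names Blobs, SetupLayer and OneToOneLayer are hypothetical.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical simplification: a "blob" is just a vector of floats.
typedef std::vector<std::vector<float> > Blobs;

class SetupLayer {
 public:
  virtual ~SetupLayer() {}

  // Non-virtual driver: fixes the order of the steps for every layer.
  void SetUp(const Blobs& bottom, Blobs* top) {
    CheckBlobCounts(bottom);
    LayerSetUp(bottom);
    Reshape(bottom, top);
  }

  // -1 means "no constraint", mirroring the base-class default.
  virtual int ExactNumBottomBlobs() const { return -1; }

 protected:
  // Optional layer-specific setup; does nothing by default.
  virtual void LayerSetUp(const Blobs&) {}
  // Every concrete layer must say how its outputs are shaped.
  virtual void Reshape(const Blobs& bottom, Blobs* top) = 0;

 private:
  void CheckBlobCounts(const Blobs& bottom) const {
    const int expected = ExactNumBottomBlobs();
    if (expected >= 0 && static_cast<int>(bottom.size()) != expected) {
      throw std::invalid_argument("wrong number of bottom blobs");
    }
  }
};

// A layer that requires exactly one input and shapes its single
// output like that input.
class OneToOneLayer : public SetupLayer {
 public:
  int ExactNumBottomBlobs() const override { return 1; }

 protected:
  void Reshape(const Blobs& bottom, Blobs* top) override {
    assert(top != nullptr);
    top->assign(1, std::vector<float>(bottom[0].size(), 0.0f));
  }
};
```

A concrete layer only overrides the hooks it needs, while the base class guarantees that the blob counts are checked before any layer-specific code runs.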
<code class="hljs cpp has-numbering" style="display: block; padding: 0px; background-color: transparent; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-top-left-radius: 0px; border-top-right-radius: 0px; border-bottom-right-radius: 0px; border-bottom-left-radius: 0px; word-wrap: normal; background-position: initial initial; background-repeat: initial initial;">#ifndef CAFFE_LAYER_H_
#define CAFFE_LAYER_H_

#include <algorithm>
#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/device_alternate.hpp"

namespace caffe {
// Purpose: the base class of all network layers; it defines the interface
// common to every layer.
// Forward interface: must be implemented.
// Backward interface: implemented when needed, to compute gradients.
template <typename Dtype>
class Layer {
 public:
  /**
   * Each layer defines its own setup rather than relying on a constructor.
   */
  explicit Layer(const LayerParameter& param)
    : layer_param_(param) {
      // Construct the layer from its layer parameters.
      phase_ = param.phase();
      if (layer_param_.blobs_size() > 0) {
        blobs_.resize(layer_param_.blobs_size());
        for (int i = 0; i < layer_param_.blobs_size(); ++i) {
          blobs_[i].reset(new Blob<Dtype>());
          blobs_[i]->FromProto(layer_param_.blobs(i));
        }
      }
    }
  // Destructor.
  virtual ~Layer() {}

  /**
   * Implements the common setup functionality.
   *
   * @param bottom the input blobs of the layer, with their shapes already set
   * @param top the output blobs, which need to be reshaped
   * Calls LayerSetUp to do the layer-specific setup,
   * calls Reshape to shape the top blobs,
   * and sets the loss weights.
   * This method should not be overridden.
   */
  void SetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    CheckBlobCounts(bottom, top);
    LayerSetUp(bottom, top);
    Reshape(bottom, top);
    SetLossWeights(top);
  }

  /**
   * @brief Does layer-specific setup; a concrete layer should implement
   *        this method as well as Reshape.
   */
  // Set up the layer.
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}
  /**
   * @brief Adjusts the shapes of the top blobs to fit the bottom blobs.
   */
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) = 0;

  /**
   * @brief Given the bottom blobs, computes the top blobs and the loss.
   * Every layer must define a CPU version of the forward pass; a GPU
   * version is optional.
   */
  // Forward pass.
  inline Dtype Forward(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  /**
   * @brief Given the gradients of the top blobs, computes the gradients of
   *        the bottom blobs.
   * @param propagate_down a vector with one entry per bottom blob; each
   *        entry says whether the loss gradient should be propagated to
   *        that bottom blob
   */
  // Backward pass.
  inline void Backward(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom);

  /**
   * @brief Returns the learnable parameter blobs.
   */
  vector<shared_ptr<Blob<Dtype> > >& blobs() {
    return blobs_;
  }

  /**
   * @brief Returns the layer parameters.
   */
  const LayerParameter& layer_param() const { return layer_param_; }
  // Serialization.
  virtual void ToProto(LayerParameter* param, bool write_diff = false);

  /**
   * @brief Returns the scalar loss for the top blob at the given index.
   */
  inline Dtype loss(const int top_index) const {
    return (loss_.size() > top_index) ? loss_[top_index] : Dtype(0);
  }

  /**
   * @brief Sets the loss for the top blob at the given index.
   */
  inline void set_loss(const int top_index, const Dtype value) {
    if (loss_.size() <= top_index) {
      loss_.resize(top_index + 1, Dtype(0));
    }
    loss_[top_index] = value;
  }

  /**
   * @brief Returns the layer type name as a string.
   */
  virtual inline const char* type() const { return ""; }
  // Exact number of bottom blobs required.
  virtual inline int ExactNumBottomBlobs() const { return -1; }
  // Minimum number of bottom blobs.
  virtual inline int MinBottomBlobs() const { return -1; }
  // Maximum number of bottom blobs.
  virtual inline int MaxBottomBlobs() const { return -1; }
  // Exact number of top blobs required.
  virtual inline int ExactNumTopBlobs() const { return -1; }
  // Minimum number of top blobs.
  virtual inline int MinTopBlobs() const { return -1; }
  // Maximum number of top blobs.
  virtual inline int MaxTopBlobs() const { return -1; }
  // Whether the numbers of bottom and top blobs must be equal.
  virtual inline bool EqualNumBottomTopBlobs() const { return false; }
  // Whether top blobs are created automatically.
  virtual inline bool AutoTopBlobs() const { return false; }
  // Whether forced backpropagation is allowed for the given bottom blob.
  virtual inline bool AllowForceBackward(const int bottom_index) const {
    return true;
  }
  // Whether gradients are computed for the given parameter blob.
  inline bool param_propagate_down(const int param_id) {
    return (param_propagate_down_.size() > param_id) ?
        param_propagate_down_[param_id] : false;
  }
  // Sets whether gradients are computed for the given parameter blob.
  inline void set_param_propagate_down(const int param_id, const bool value) {
    if (param_propagate_down_.size() <= param_id) {
      param_propagate_down_.resize(param_id + 1, true);
    }
    param_propagate_down_[param_id] = value;
  }

 protected:
  // Layer parameters.
  LayerParameter layer_param_;
  // Phase (mode) of the layer.
  Phase phase_;
  // Blobs that store the layer's learnable parameters.
  vector<shared_ptr<Blob<Dtype> > > blobs_;
  // Per-parameter flags: whether to backpropagate.
  vector<bool> param_propagate_down_;
  // Stores the loss for each top blob.
  vector<Dtype> loss_;
  // CPU implementation of the forward pass.
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) = 0;
  // GPU implementation of the forward pass; falls back to the CPU version.
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    // LOG(WARNING) << "Using CPU code as backup.";
    return Forward_cpu(bottom, top);
  }
  // CPU implementation of the backward pass.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) = 0;
  // GPU implementation of the backward pass; falls back to the CPU version.
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) {
    // LOG(WARNING) << "Using CPU code as backup.";
    Backward_cpu(top, propagate_down, bottom);
  }
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// Check that the numbers of bottom and top blobs match what this layer specifies.</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">virtual</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">void</span> CheckBlobCounts(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& bottom,
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& top) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (ExactNumBottomBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_EQ(ExactNumBottomBlobs(), bottom.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer takes "</span> << ExactNumBottomBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" bottom blob(s) as input."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (MinBottomBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_LE(MinBottomBlobs(), bottom.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer takes at least "</span> << MinBottomBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" bottom blob(s) as input."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (MaxBottomBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_GE(MaxBottomBlobs(), bottom.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer takes at most "</span> << MaxBottomBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" bottom blob(s) as input."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (ExactNumTopBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_EQ(ExactNumTopBlobs(), top.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer produces "</span> << ExactNumTopBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" top blob(s) as output."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (MinTopBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_LE(MinTopBlobs(), top.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer produces at least "</span> << MinTopBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" top blob(s) as output."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (MaxTopBlobs() >= <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>) {
CHECK_GE(MaxTopBlobs(), top.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer produces at most "</span> << MaxTopBlobs()
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" top blob(s) as output."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (EqualNumBottomTopBlobs()) {
CHECK_EQ(bottom.size(), top.size())
<< type() << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">" Layer produces one top blob as output for each "</span>
<< <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"bottom blob input."</span>;
}
}
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// Initialize the loss weights: each nonzero loss_weight is written into the diff of its top blob.</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">inline</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">void</span> SetLossWeights(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& top) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> num_loss_weights = layer_param_.loss_weight_size();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (num_loss_weights) {
CHECK_EQ(top.size(), num_loss_weights) << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"loss_weight must be "</span>
<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"unspecified or specified once per top blob."</span>;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> (<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> top_id = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>; top_id < top.size(); ++top_id) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> Dtype loss_weight = layer_param_.loss_weight(top_id);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (loss_weight == Dtype(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>)) { <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">continue</span>; }
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">this</span>->set_loss(top_id, loss_weight);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> count = top[top_id]->count();
Dtype* loss_multiplier = top[top_id]->mutable_cpu_diff();
caffe_set(count, loss_weight, loss_multiplier);
}
}
}

DISABLE_COPY_AND_ASSIGN(Layer);
};  <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// class Layer</span>

<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// Forward pass: call the CPU or GPU implementation according to Caffe::mode(), then compute the loss value.</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">template</span> <<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">typename</span> Dtype>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">inline</span> Dtype Layer<Dtype>::Forward(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& bottom,
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& top) {
Dtype loss = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>;
Reshape(bottom, top);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">switch</span> (Caffe::mode()) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">case</span> Caffe::CPU:
Forward_cpu(bottom, top);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> (<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> top_id = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>; top_id < top.size(); ++top_id) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (!<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">this</span>->loss(top_id)) { <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">continue</span>; }
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> count = top[top_id]->count();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> Dtype* data = top[top_id]->cpu_data();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> Dtype* loss_weights = top[top_id]->cpu_diff();
loss += caffe_cpu_dot(count, data, loss_weights);
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">break</span>;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">case</span> Caffe::GPU:
Forward_gpu(bottom, top);
<span class="hljs-preprocessor" style="color: rgb(68, 68, 68); box-sizing: border-box;">#ifndef CPU_ONLY</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> (<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> top_id = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>; top_id < top.size(); ++top_id) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> (!<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">this</span>->loss(top_id)) { <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">continue</span>; }
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> count = top[top_id]->count();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> Dtype* data = top[top_id]->gpu_data();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> Dtype* loss_weights = top[top_id]->gpu_diff();
Dtype blob_loss = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>;
caffe_gpu_dot(count, data, loss_weights, &blob_loss);
loss += blob_loss;
}
<span class="hljs-preprocessor" style="color: rgb(68, 68, 68); box-sizing: border-box;">#endif</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">break</span>;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">default</span>:
LOG(FATAL) << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"Unknown caffe mode."</span>;
}
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> loss;
}

<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// Backward pass: propagate gradients, calling the CPU or GPU version according to Caffe::mode().</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// propagate_down controls whether the gradient is propagated to each corresponding bottom blob.</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">template</span> <<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">typename</span> Dtype>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">inline</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">void</span> Layer<Dtype>::Backward(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& top,
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">bool</span>></span>& propagate_down,
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>*>& bottom) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">switch</span> (Caffe::mode()) {
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">case</span> Caffe::CPU:
Backward_cpu(top, propagate_down, bottom);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">break</span>;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">case</span> Caffe::GPU:
Backward_gpu(top, propagate_down, bottom);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">break</span>;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">default</span>:
LOG(FATAL) << <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"Unknown caffe mode."</span>;
}
}

<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// Serialize the layer parameters to a protocol buffer; the actual writing is done by each blob's ToProto.</span>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">template</span> <<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">typename</span> Dtype>
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">void</span> Layer<Dtype>::ToProto(LayerParameter* param, <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">bool</span> write_diff) {
param->Clear();
param->CopyFrom(layer_param_);
param->clear_blobs();
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> (<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">int</span> i = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>; i < blobs_.size(); ++i) {
blobs_[i]->ToProto(param->add_blobs(), write_diff);
}
}

}  <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">// namespace caffe</span>

<span class="hljs-preprocessor" style="color: rgb(68, 68, 68); box-sizing: border-box;">#endif  <span class="hljs-comment" style="box-sizing: border-box;">// CAFFE_LAYER_H_</span></span></code>
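The interplay between `SetLossWeights` and `Forward` in the listing above can be hard to see at first: the loss weight of each top blob is stored in that blob's diff, and the loss is then the dot product of the blob's data with its diff. The following is a minimal sketch of that idea only, not Caffe code. `ToyBlob`, `SetToyLossWeights`, and `ToyForwardLoss` are hypothetical names introduced for illustration, and plain `std::vector<float>` stands in for Caffe's `Blob<Dtype>` storage.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a Caffe top blob (not the real Blob<Dtype>):
// `data` holds the forward output, `diff` is reused to hold loss weights.
struct ToyBlob {
  std::vector<float> data;
  std::vector<float> diff;
};

// Mirrors the idea of SetLossWeights: for every top blob with a nonzero
// loss weight, fill its diff with that weight (one copy per element).
inline void SetToyLossWeights(std::vector<ToyBlob>* tops,
                              const std::vector<float>& loss_weights) {
  assert(tops->size() == loss_weights.size());
  for (std::size_t i = 0; i < tops->size(); ++i) {
    if (loss_weights[i] == 0.0f) { continue; }  // blob does not contribute
    (*tops)[i].diff.assign((*tops)[i].data.size(), loss_weights[i]);
  }
}

// Mirrors the loss accumulation in Forward: sum of dot(data, diff) over
// all top blobs whose loss weight is nonzero.
inline float ToyForwardLoss(const std::vector<ToyBlob>& tops,
                            const std::vector<float>& loss_weights) {
  assert(tops.size() == loss_weights.size());
  float loss = 0.0f;
  for (std::size_t i = 0; i < tops.size(); ++i) {
    if (loss_weights[i] == 0.0f) { continue; }
    for (std::size_t j = 0; j < tops[i].data.size(); ++j) {
      loss += tops[i].data[j] * tops[i].diff[j];
    }
  }
  return loss;
}
```

With two top blobs holding `{1, 2, 3}` and `{10}` and loss weights `{0.5, 0}`, only the first blob contributes to the loss. The second blob is skipped in both functions, just as the real code skips blobs whose loss weight is zero.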