您的位置:首页 > 编程语言 > C语言/C++

学习笔记: 源码 softmax_loss_layer.cpp 略析

2017-05-16 21:47 399 查看
SoftmaxWithLossLayer

SoftmaxWithLossLayer 的功能相当于 SoftmaxLayer + MultinomialLogisticLossLayer。
但是直接使用SoftmaxWithLossLayer 会使得计算更方便,速度更快,而且计算精度也会提高。
关于解释,可以参考文末的链接。

SoftmaxWithLossLayer 主要也是三个部分,这里简单介绍一下,然后主要是对后向传播公式进行推导说明。

1. LayerSetUp()
template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
...
softmax_layer_ = LayerRegistry<Dtype>::CreateLayer(softmax_param);  //内部创建一个softmax 层,用于计算概率 prob
...
}


2. forward()
template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
...
loss -= log(std::max(prob_data[i * dim + label_value * inner_num_ + j],
Dtype(FLT_MIN)));
...}
Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
normalization_, outer_num_, inner_num_, count);
top[0]->mutable_cpu_data()[0] = loss / normalizer;
...}


3. backward()
template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
...
if (propagate_down[0]) {
...
caffe_copy(prob_.count(), prob_data, bottom_diff);   //先将prob 拷贝到 bottom_diff 中
...
for (int i = 0; i < outer_num_; ++i) {
for (int j = 0; j < inner_num_; ++j) {
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
...
bottom_diff[i * dim + label_value * inner_num_ + j] -= 1;  //然后让label对应的 bottom_diff 减1
...
}
}
}
// Scale gradient
Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
normalization_, outer_num_, inner_num_, count);
Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer;
caffe_scal(prob_.count(), loss_weight, bottom_diff);  //最后乘上一个scale 系数
}
}


假设SoftmaxWithLossLayer 的输入为a, 输出为z 。有N个样本,K个类别。
前向的过程中,计算概率公式如下:



则输出的 loss 为:



则相应的梯度为:



参考链接:

http://freemind.pluskid.org/machine-learning/softmax-vs-softmax-loss-numerical-stability/
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: