yolo v2 损失函数源码(训练核心代码)解读和其实现原理
2017-09-19 09:29
627 查看
前提说明:
1, 关于 yolo 和 yolo v2 的详细解释请移步至如下两个链接,或者直接看论文(我自己有想写 yolo 的教程,但思前想后下面两个链接中的文章质量实在是太好了_(:з」∠)_)
yolo: https://zhuanlan.zhihu.com/p/24916786?refer=xiaoleimlnote
yolo v2: https://zhuanlan.zhihu.com/p/25167153
2, 本文仅解读 yolo v2 的 loss 函数的源码,该代码请使用如下命令
git clone https://github.com/pjreddie/darknet
后打开 src/region_layer.c 查看
3, yolo 的官方网站地址为:https://pjreddie.com/darknet/yolo/
4, 我调试代码时使用的命令是:
./darknet detector train cfg/voc.data cfg/yolo-voc.cfg darknet19_448.conv.23
最新版yolo v2的损失函数的源码解读(解释无GPU版本),如下:
1, 关于 yolo 和 yolo v2 的详细解释请移步至如下两个链接,或者直接看论文(我自己有想写 yolo 的教程,但思前想后下面两个链接中的文章质量实在是太好了_(:з」∠)_)
yolo: https://zhuanlan.zhihu.com/p/24916786?refer=xiaoleimlnote
yolo v2: https://zhuanlan.zhihu.com/p/25167153
2, 本文仅解读 yolo v2 的 loss 函数的源码,该代码请使用如下命令
git clone https://github.com/pjreddie/darknet
后打开 src/region_layer.c 查看
3, yolo 的官方网站地址为:https://pjreddie.com/darknet/yolo/
4, 我调试代码时使用的命令是:
./darknet detector train cfg/voc.data cfg/yolo-voc.cfg darknet19_448.conv.23
最新版yolo v2的损失函数的源码解读(解释无GPU版本),如下:
void forward_region_layer(const region_layer l, network_state state) { int i,j,b,t,n; //size代表着每个box需要预测出来的参数。 int size = l.coords + l.classes + 1; memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float)); #ifndef GPU flatten(l.output, l.w*l.h, size*l.n, l.batch, 1); #endif for (b = 0; b < l.batch; ++b){ for(i = 0; i < l.h*l.w*l.n; ++i){ int index = size*i + b*l.outputs; l.output[index + 4] = logistic_activate(l.output[index + 4]); } } #ifndef GPU if (l.softmax_tree){ for (b = 0; b < l.batch; ++b){ for(i = 0; i < l.h*l.w*l.n; ++i){ int index = size*i + b*l.outputs; softmax_tree(l.output + index + 5, 1, 0, 1, l.softmax_tree, l.output + index + 5); } } } else if (l.softmax){ for (b = 0; b < l.batch; ++b){ for(i = 0; i < l.h*l.w*l.n; ++i){ int index = size*i + b*l.outputs; softmax(l.output + index + 5, l.classes, 1, l.output + index + 5); } } } #endif if(!state.train) return; memset(l.delta, 0, l.outputs * l.batch * sizeof(float)); float avg_iou = 0; float recall = 0; float avg_cat = 0; float avg_obj = 0; float avg_anyobj = 0; int count = 0; int class_count = 0; *(l.cost) = 0; //这里是对批处理的所有图像进行前向求损失值。 for (b = 0; b < l.batch; ++b) { //没有使用这个softmax分类器,即不会进入这部分代码。 if(l.softmax_tree){ int onlyclass = 0; for(t = 0; t < 30; ++t){ box truth = float_to_box(state.truth + t*5 + b*l.truths); if(!truth.x) break; int class = state.truth[t*5 + b*l.truths + 4]; float maxp = 0; int maxi = 0; if(truth.x > 100000 && truth.y > 100000){ for(n = 0; n < l.n*l.w*l.h; ++n){ int index = size*n + b*l.outputs + 5; float scale = l.output[index-1]; float p = scale*get_hierarchy_probability(l.output + index, l.softmax_tree, class); if(p > maxp){ maxp = p; maxi = n; } } int index = size*maxi + b*l.outputs + 5; delta_region_class(l.output, l.delta, index, class, l.classes, l.softmax_tree, l.class_scale, &avg_cat); ++class_count; onlyclass = 1; break; } } if(onlyclass) continue; } /* 这里的l.h,l.w分别是最后卷积输出的特征图分辨率。l.n是anchor box的个数,这个机制是借鉴Faster R-CNN 的回归方法。l.n这个参数跟配置文件的anchors、num有关,值就是num一样。其跟V1版的不同,V1版的是不管最后输出 的特征图分辨率多少都是把起分成7*7个cell,而V2的每个特征点就是一个cell,优点就是:能回归和识别更小的物体。 */ for (j = 0; j < l.h; ++j) { for (i = 0; i < l.w; ++i) { //这个l.n是代表着特征点需要进行预测的不同尺寸的box个数,box宽高大小跟配置文件里的anchor系数有关。 for (n = 0; n < l.n; ++n) { int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs; box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h); float best_iou = 0; int best_class = -1; //这里是假设每个特征点cell最多只能有30个物体坐落在相同位置。其实这里的阈值影响不大的,其主要跟truth.x有关。 for(t = 0; t <30; ++t){ // get truth_box's x, y, w, h box truth = float_to_box(state.truth + t*5 + b*l.truths); // 遍历完图片中的所有物体后退出 if (!truth.x) break; float iou = box_iou(pred, truth); //选出iou最大那个框作为最后预测框~ if (iou > best_iou) { best_class = state.truth[t*5 + b*l.truths + 4]; best_iou = iou; } } //计算有没有目标的梯度 avg_anyobj += l.output[index + 4]; l.delta[index + 4] = l.noobject_scale * ((0 - l.output[index + 4]) * logistic_gradient(l.output[index + 4])); if(l.classfix == -1) l.delta[index + 4] = l.noobject_scale * ((best_iou - l.output[index + 4]) * logistic_gradient(l.output[index + 4])); else{ if (best_iou > l.thresh) { l.delta[index + 4] = 0; if(l.classfix > 0){ delta_region_class(l.output, l.delta, index + 5, best_class, l.classes, l.softmax_tree, l.class_scale*(l.classfix == 2 ? l.output[index + 4] : 1), &avg_cat); ++class_count; } } } //这里要训练的图片张数达到12800后能进入 if(*(state.net.seen) < 12800){ box truth = {0}; truth.x = (i + .5)/l.w; truth.y = (j + .5)/l.h; truth.w = l.biases[2*n]; truth.h = l.biases[2*n+1]; if(DOABS){ truth.w = l.biases[2*n]/l.w; truth.h = l.biases[2*n+1]/l.h; } // 将预测的 tx, ty, tw, th 和 实际box计算得出的 tx',ty', tw', th' 的差存入 l.delta delta_region_box(truth, l.output, l.biases, n, index, i, j, l.w, l.h, l.delta, .01); } } } } //运行到这步,则所有特征图上的所有格子都被标注,即代表有没有物体在此区域。 for(t = 0; t < 30; ++t){ // get truth_box's x, y, w, h box truth = float_to_box(state.truth + t*5 + b*l.truths); if(!truth.x) break; float best_iou = 0; int best_index = 0; int best_n = 0; i = (truth.x * l.w); j = (truth.y * l.h); //printf("%d %f %d %f\n", i, truth.x*l.w, j, truth.y*l.h); // 上面获得了 truth box 的 x,y,w,h,这里讲 truth box 的 x,y 偏移到 0,0,记 //为 truth_shift.x, truth_shift.y,这么做是为了方便计算 iou box truth_shift = truth; truth_shift.x = 0; truth_shift.y = 0; //printf("index %d %d\n",i, j); //这里是计算具有真实物体的地方与anchor boxs的匹配值。 for(n = 0; n < l.n; ++n){ //获得box的index。其中size是每个box需要计算的参数,(j*l.w*l.n + i*l.n + n)计算的是第几个格子 //b*l.outputs计算的是第几张输入图片的特征图,这样算就是为了计算位置。 int index = size*(j*l.w*l.n + i*l.n + n) + b*l.outputs; //获得box的预测,这里先是坐标位置x,y,w,h,而剩下的两个confidence放到后面, box pred = get_region_box(l.output, l.biases, n, index, i, j, l.w, l.h); //box的w,h是根据anchors生成的,其中l.biases就是配置文件里的那些anchors参数 if(l.bias_match){ pred.w = l.biases[2*n]; pred.h = l.biases[2*n+1]; if(DOABS){ pred.w = l.biases[2*n]/l.w; pred.h = l.biases[2*n+1]/l.h; } } //printf("pred: (%f, %f) %f x %f\n", pred.x, pred.y, pred.w, pred.h); //这里也把box位置移到0,0;这么做是为了方便计算IOU。 pred.x = 0; pred.y = 0; float iou = box_iou(pred, truth_shift); if (iou > best_iou){ best_index = index; best_iou = iou; best_n = n; } } //printf("%d %f (%f, %f) %f x %f\n", best_n, best_iou, truth.x, truth.y, truth.w, truth.h); // 计算 box 和 truth box 的 iou float iou = delta_region_box(truth, l.output, l.biases, best_n, best_index, i, j, l.w, l.h, l.delta, l.coord_scale); //如果大于阈值则召回率加1. if(iou > .5) recall += 1; avg_iou += iou; //运行到这里,位置的回归基本完成,下面主要是进行目标分类的操作 //l.delta[best_index + 4] = iou - l.output[best_index + 4]; avg_obj += l.output[best_index + 4]; //这里logistic_gradient把具有目标的区域进行逻辑回归分类,计算其输出的类别分数。 l.delta[best_index + 4] = l.object_scale * (1 - l.output[best_index + 4]) * logistic_gradient(l.output[best_index + 4]); if (l.rescore) { // 用 iou 代替上面的 1(经调试,l.rescore = 1,因此能走到这里) l.delta[best_index + 4] = l.object_scale * (iou - l.output[best_index + 4]) * logistic_gradient(l.output[best_index + 4]); } // 获得真实的 class int class = state.truth[t*5 + b*l.truths + 4]; if (l.map) class = l.map[class]; // 把所有 class 的预测概率与真实 class 的 0/1 的差 * scale,然后存入 l.delta 里相应 class 序号的位置 delta_region_class(l.output, l.delta, best_index + 5, class, l.classes, l.softmax_tree, l.class_scale, &avg_cat); ++count; ++class_count; } } //printf("\n"); #ifndef GPU flatten(l.delta, l.w*l.h, size*l.n, l.batch, 0); #endif // 现在,l.delta 中的每一个位置都存放了 class、confidence、x, y, w, h 的差,于是通过 mag_array 遍历所有位置,计算每个位置的平方的和后开根 // 然后利用 pow 函数求平方 *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2); printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count); }注:上面的代码解释是个人参考网上资料后的一些见解,其中如有不对的地方,大家可以指出了,通过修改完善造福更多人。
相关文章推荐
- 微信云控系统的实现原理,微信云控系统源码之服务器推送的实现及其核心代码
- 微信群控系统的实现原理,微信群控系统源码的核心实现代码
- jQuery的实现原理的模拟代码 -1 核心部分
- 神经网络之激活函数 dropout原理解读 BatchNormalization 代码实现
- JDK源码之解读String最终类的trim()方法实现原理
- 微信群控系统源码的实现原理,核心源码实现,核心框架。
- QQ消息群发器实现原理及核心代码
- 从源码理解Spring原理,并用代码实现简易Spring框架
- JDK源码之解读String最终类的trim()方法实现原理
- ThreadPoolExecutor核心实现原理和源码解析<一>
- ConcurrentHashMap实现原理和源码解读
- MyBatis实现原理和代码解读
- Spring2.5源码解读 之 基于annotation的Controller实现原理分析(1)
- yolo v2 损失函数源码解读
- DQN 原理(三):DQN 训练代码实现
- 快速傅里叶变换(FFT)的原理、实现及代码解析(附C#源码)
- 解读Google官方SwipeRefreshLayout控件源码,带你揭秘Android下拉刷新的实现原理
- jQuery的实现原理的模拟代码 -1 核心部分
- Spring2.5源码解读 之 基于annotation的Controller实现原理分析(1)
- OkHttp源码解读总结(六)--->OkHttp拦截器核心代码总结