您的位置:首页 > 编程语言

OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)

2016-01-08 08:40 459 查看
HaarTraining关键的部分是建立基分类器classifier,OpenCV中所採用的是CART(决策树的一种):通过调用cvCreateMTStumpClassifier来完毕。

这里我讨论利用回归的方法来分裂结点。分类的方法仅仅是在分裂结点的方法与之不同而已。

cvCreateMTStumpClassifier

//设置决策树分类误差计算方法
stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;

//设置class step和ydata
ydata = trainClasses->data.ptr;
if( trainClasses->rows == 1 )
{
m = trainClasses->cols;
ystep = CV_ELEM_SIZE( trainClasses->type );
}
else
{
m = trainClasses->rows;
ystep = trainClasses->step;
}
//设置weight step和wdata
wdata = weights->data.ptr;
if( weights->rows == 1 )
{
assert( weights->cols == m );
wstep = CV_ELEM_SIZE( weights->type );
}
else
{
assert( weights->rows == m );
wstep = weights->step;
}

//设置步长,地址等參数,用于获取idxCache内容
if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
{
sortedtype =
CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
|| sortedtype == CV_32FC1 );
sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
sortedsstep = CV_ELEM_SIZE( sortedtype );
sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
}

if( trainData == NULL )
{
assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
assert( n > 0 );
}
//设置步长,地址等參数,用于获取dataCache内容
else
{
assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
data = trainData->data.ptr;
if( CV_IS_ROW_SAMPLE( flags ) )
{
cstep = CV_ELEM_SIZE( trainData->type );
sstep = trainData->step;
assert( m == trainData->rows );
datan = n = trainData->cols;
}
else
{
sstep = CV_ELEM_SIZE( trainData->type );
cstep = trainData->step;
assert( m == trainData->cols );
datan = n = trainData->rows;
}
if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
{
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
}
}
可能研究代码到这里的朋友仍然不清楚idxCache和valCache的作用。

我这里做一点简单的说明:

valCache是设置在训练前有多少特征值被提前算出存放在内存中。idxCache是valCache中每种特征按特征值从小到大排列的样本的序号。

内存大小是通过执行程序的命令行參数设置的。在cvHaartraining.cpp中我们能够找到这句话,当中float和short,各自是valCache和idxCache存放内容的基本类型。

//1MB == 1048576B  计算一个样本中有多少个特征能被pre计算放在内存中
numprecalculated = (int) ( ((size_t) mem) * ((size_t) 1048576) /
( ((size_t) (npos + nneg)) * (sizeof( float ) + sizeof( short )) ) );


为了方便理解,我把两者的内存模型画了出来





要注意idxCache中每行的index排列是示意图。

比方第一行代表feature1从小到大的index顺序。从图中能够看出。sample1的特征值feature1 < sample0的特征值feature1<...<sample n < sample n-1。

利用idxCache数组我们能够方便按特征值的从小到大遍历valCache,而且节省了空间。从float->short。

理解了这两个cache之后我们再回到上面的代码,能够发现这里做得仅仅是设置步长和cache首地址的一些操作。为以下開始的遍历做好准备。

跳过一些变量的初始化步骤,我们来到构建决策树stump的部分,而且为了方便阅读核心代码。去掉了其它一些基于移植的代码

while( t_compidx < n )
{
//选择计算前100种特征
t_n = portion;
if( t_compidx < datan )
{
t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
t_data = data;
t_cstep = cstep;
t_sstep = sstep;
}
else
{

}

if( sorteddata != NULL )
{

}
else
{
/* have sorted indices */
switch( sortedtype )
{
case CV_16SC1:

//选择某个样本的某个特征值作为结点
for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
{
if( findStumpThreshold_16s[stumperror](
t_data + ti * t_cstep, t_sstep,
wdata, wstep, ydata, ystep,
sorteddata + ti * sortedcstep, sortedsstep, sortedm,
&lerror, &rerror,
&threshold, &left, &right,
&sumw, &sumwy, &sumwyy ) )
{
optcompidx = ti;
}
}
break;

}
}
}}
这里datan代表的是一个检測窗体包括的特征数目。portion代表以多少的行为单位进行计算。每一个循环选取valCache中的portion行进行计算,应该是为了发挥并行计算的优势,假如设置了并行计算的宏的话。

findStumpThreshold_32[stumperror]是一个函数指针,利用这个函数我们能够选择某个样本的某个特征值作为决策树的一个结点。

这里结点分裂方法我选择的是最小残差和的方法。即统计利用某个特征值进行分类后,左右子树中类间的残差之和。最小残差和相应的特征值就是满足要求的结点。

OpenCV1.0中利用宏定义的方式实现了这个函数

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \
uchar* data, size_t datastep,                                                    \
uchar* wdata, size_t wstep,                                                      \
uchar* ydata, size_t ystep,                                                      \
uchar* idxdata, size_t idxstep, int num,                                         \
float* lerror,                                                                   \
float* rerror,                                                                   \
float* threshold, float* left, float* right,                                     \
float* sumw, float* sumwy, float* sumwyy )                                       \
{                                                                                        \
int found = 0;                                                                       \
float wyl  = 0.0F;                                                                   \
float wl   = 0.0F;                                                                   \
float wyyl = 0.0F;                                                                   \
float wyr  = 0.0F;                                                                   \
float wr   = 0.0F;                                                                   \
\
float curleft  = 0.0F;                                                               \
float curright = 0.0F;                                                               \
float* prevval = NULL;                                                               \
float* curval  = NULL;                                                               \
float curlerror = 0.0F;                                                              \
float currerror = 0.0F;                                                              \
float wposl;                                                                         \
float wposr;                                                                         \
\
int i = 0;                                                                           \
int idx = 0;                                                                         \
\
wposl = wposr = 0.0F;                                                                \
if( *sumw == FLT_MAX )                                                               \
{                                                                                    \
/* calculate sums */                                                             \
float *y = NULL;                                                                 \
float *w = NULL;                                                                 \
float wy = 0.0F;                                                                 \
\
*sumw   = 0.0F;                                                                  \
*sumwy  = 0.0F;                                                                  \
*sumwyy = 0.0F;                                                                  \
for( i = 0; i < num; i++ )                                                       \
{                                                                                \
idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
w = (float*) (wdata + idx * wstep);                                          \
*sumw += *w;                                                                 \
y = (float*) (ydata + idx * ystep);                                          \
wy = (*w) * (*y);                                                            \
*sumwy += wy;                                                                \
*sumwyy += wy * (*y);                                                        \
}                                                                                \
}                                                                                    \
\
for( i = 0; i < num; i++ )                                                           \
{                                                                                    \
idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
curval = (float*) (data + idx * datastep);                                       \
/* for debug purpose */                                                         \
if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
\
wyr  = *sumwy - wyl;                                                             \
wr   = *sumw  - wl;                                                              \
\
if( wl > 0.0 ) curleft = wyl / wl;                                               \
else curleft = 0.0F;                                                             \
\
if( wr > 0.0 ) curright = wyr / wr;                                              \
else curright = 0.0F;                                                            \
\
error                                                                            \
\
if( curlerror + currerror < (*lerror) + (*rerror) )                              \
{                                                                                \
(*lerror) = curlerror;                                                       \
(*rerror) = currerror;                                                       \
*threshold = *curval;                                                        \
if( i > 0 ) {                                                                \
*threshold = 0.5F * (*threshold + *prevval);                             \
}                                                                            \
*left  = curleft;                                                            \
*right = curright;                                                           \
found = 1;                                                                   \
}                                                                                \
\
do                                                                               \
{                                                                                \
wl  += *((float*) (wdata + idx * wstep));                                    \
wyl += (*((float*) (wdata + idx * wstep)))                                   \
* (*((float*) (ydata + idx * ystep)));                                   \
wyyl += *((float*) (wdata + idx * wstep))                                    \
* (*((float*) (ydata + idx * ystep)))                                    \
* (*((float*) (ydata + idx * ystep)));                                   \
}                                                                                \
while( (++i) < num &&                                                            \
( *((float*) (data + (idx =                                                  \
(int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
== *curval ) );                                                          \
--i;                                                                             \
prevval = curval;                                                                \
} /* for each value */                                                               \
\
return found;                                                                        \
}


这里几个关键的变量:ydata是某个特征所代表的类别,正负样本分别以1。-1进行标注,wdata是正负样本相应的权值,data指的就是valCache的某一行。

程序进来的时候推断sumw是否初始化,没有初始化就进行赋值。因为同一个训练集每一个样本都仅仅相应一个ydata和wdata(每一个样本相应非常多个Haar特征,两者有差别),因此这里的sumw,sumwyy,sumwy都是一个确定的值。提前计算好,在后面的迭代中就不必反复计算。

接下来,依据idxCache中某一行(视迭代次数而定)的index,按从小到大的顺序遍历ValCache中相应行的特征值。也就是不相同本的同一特征值。并将其作为结点,尝试对样本进行划分。curleft和curright分别代表左右子树的类别的加权平均值。然后利用error宏计算左右子树的残差

#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
/* calculate error (sum of squares)          */                                  \
/* err = sum( w * (y - left(rigt)Val)^2 )    */                                  \
curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
)


最后的一个do-while循环。就是用来跳过和当前结点同样的特征值。尽管以兴许的、同样的值作为结点划分左右子树,残差平方和可能会改变,可是决策树划分的最小单位是特征值的种类,由于在利用决策树进行分类的时候,必须对同样的特征值做出一样的决策(该划入左子树还是该划入右子树)。

总结

HaarTraining的代码是有4,5K行。可是认真学习之后会收获非常多机器学习的算法和优秀代码的书写习惯。我会随着学习的深入不断更新自己的源代码研究体会,写的尽管不是非常具体,可是力求把重点突出来,将自己在阅读代码时碰到的困惑总结出来,给相同学习Training算法的朋友一点点帮助
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: