您的位置:首页 > 运维架构

学习SVM(一) SVM模型训练与分类的OpenCV实现

2017-03-29 21:47 513 查看
学习SVM(一) SVM模型训练与分类的OpenCV实现

学习SVM(二) 如何理解支持向量机的最大分类间隔

学习SVM(三)理解SVM中的对偶问题

学习SVM(四) 理解SVM中的支持向量(Support Vector)

学习SVM(五)理解线性SVM的松弛因子

Andrew Ng 在斯坦福大学的机器学习公开课上这样评价支持向量机:

support vector machines is the supervised learning algorithm that many people consider the most effective off-the-shelf supervised learning algorithm.That point of view is debatable,but there are many people that hold that point of view.

可见,在监督学习算法中支持向量机有着非常广泛的应用,而且在解决图像分类问题时有着优异的效果。

Opencv集成了这种学习算法,它被包含在ml模块下的CvSVM类中,下面我们用Opencv实现SVM的模型训练加载模型实现分类,为了理解起来更加直观,我们将这两个部分用两个工程来实现。

模型训练

#include <stdio.h>
#include <time.h>
#include <opencv2/opencv.hpp>
#include <opencv/cv.h>
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
#include <io.h>

using namespace std;
using namespace cv;

void getFiles( string path, vector<string>& files);
void getBubble(Mat& trainingImages, vector<int>& trainingLabels);
void getNoBubble(Mat& trainingImages, vector<int>& trainingLabels);

int main()
{
//获取训练数据
Mat classes;
Mat trainingData;
Mat trainingImages;
vector<int> trainingLabels;
getBubble(trainingImages, trainingLabels);
getNoBubble(trainingImages, trainingLabels);
Mat(trainingImages).copyTo(trainingData);
trainingData.convertTo(trainingData, CV_32FC1);
Mat(trainingLabels).copyTo(classes);
//配置SVM训练器参数
CvSVMParams SVM_params;
SVM_params.svm_type = CvSVM::C_SVC;
SVM_params.kernel_type = CvSVM::LINEAR;
SVM_params.degree = 0;
SVM_params.gamma = 1;
SVM_params.coef0 = 0;
SVM_params.C = 1;
SVM_params.nu = 0;
SVM_params.p = 0;
SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01);
//训练
CvSVM svm;
svm.train(trainingData, classes, Mat(), Mat(), SVM_params);
//保存模型
svm.save("svm.xml");
cout<<"训练好了!!!"<<endl;
getchar();
return 0;
}
void getFiles( string path, vector<string>& files )
{
long   hFile   =   0;
struct _finddata_t fileinfo;
string p;
if((hFile = _findfirst(p.assign(path).append("\\*").c_str(),&fileinfo)) !=  -1)
{
do
{
if((fileinfo.attrib &  _A_SUBDIR))
{
if(strcmp(fileinfo.name,".") != 0  &&  strcmp(fileinfo.name,"..") != 0)
getFiles( p.assign(path).append("\\").append(fileinfo.name), files );
}
else
{
files.push_back(p.assign(path).append("\\").append(fileinfo.name) );
}
}while(_findnext(hFile, &fileinfo)  == 0);

_findclose(hFile);
}
}
void getBubble(Mat& trainingImages, vector<int>& trainingLabels)
{
char * filePath = "D:\\train\\has\\train";
vector<string> files;
getFiles(filePath, files );
int number = files.size();
for (int i = 0;i < number;i++)
{
Mat  SrcImage=imread(files[i].c_str());
SrcImage= SrcImage.reshape(1, 1);
trainingImages.push_back(SrcImage);
trainingLabels.push_back(1);
}
}
void getNoBubble(Mat& trainingImages, vector<int>& trainingLabels)
{
char * filePath = "D:\\train\\no\\train";
vector<string> files;
getFiles(filePath, files );
int number = files.size();
for (int i = 0;i < number;i++)
{
Mat  SrcImage=imread(files[i].c_str());
SrcImage= SrcImage.reshape(1, 1);
trainingImages.push_back(SrcImage);
trainingLabels.push_back(0);
}
}


整个训练过程可以分为一下几个部分:

数据准备:

该例程中一个定义了三个子程序用来实现数据准备工作:

getFiles()用来遍历文件夹下所有文件,可以参考:

http://blog.csdn.net/chaipp0607/article/details/53914954

getBubble()用来获取有气泡的图片和与其对应的Labels,该例程将Labels定为1。

getNoBubble()用来获取没有气泡的图片与其对应的Labels,该例程将Labels定为0。

getBubble()与getNoBubble()将获取一张图片后会将图片(特征)写入到容器中,紧接着会将标签写入另一个容器中,这样就保证了特征和标签是一一对应的关系
push_back(0)
或者
push_back(1)
其实就是我们贴标签的过程。

trainingImages.push_back(SrcImage);
trainingLabels.push_back(0);


在主函数中,将getBubble()与getNoBubble()写好的包含特征的矩阵拷贝给trainingData,将包含标签的vector容器进行类型转换后拷贝到trainingLabels里,至此,数据准备工作完成,trainingData与trainingLabels就是我们要训练的数据。

Mat classes;
Mat trainingData;
Mat trainingImages;
vector<int> trainingLabels;
getBubble(trainingImages, trainingLabels);
getNoBubble(trainingImages, trainingLabels);
Mat(trainingImages).copyTo(trainingData);
trainingData.convertTo(trainingData, CV_32FC1);
Mat(trainingLabels).copyTo(classes);


特征选取

其实特征提取和数据的准备是同步完成的,我们最后要训练的也是正负样本的特征。本例程中同样在getBubble()与getNoBubble()函数中完成特征提取工作,只是我们简单粗暴将整个图的所有像素作为了特征,因为我们关注更多的是整个的训练过程,所以选择了最简单的方式完成特征提取工作,除此中外,特征提取的方式有很多,比如LBP,HOG等等。

SrcImage= SrcImage.reshape(1, 1);


我们利用reshape()函数完成特征提取,原型如下:

Mat reshape(int cn, int rows=0) const;


可以看到该函数的参数非常简单,cn为新的通道数,如果cn = 0,表示通道数不会改变。参数rows为新的行数,如果rows = 0,表示行数不会改变。我们将参数定义为reshape(1, 1)的结果就是原图像对应的矩阵将被拉伸成一个一行的向量,作为特征向量。

参数配置

参数配置是SVM的核心部分,在Opencv中它被定义成一个结构体类型,如下:

struct CV_EXPORTS_W_MAP CvSVMParams
{
CvSVMParams();
CvSVMParams(
int svm_type,
int kernel_type,
double degree,
double coef0,
double Cvalue,
double p,
CvMat* class_weights,
CvTermCriteria term_crit );
CV_PROP_RW int         svm_type;
CV_PROP_RW int         kernel_type;
CV_PROP_RW double      degree; // for poly
CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
CV_PROP_RW double      coef0;  // for poly/sigmoid
CV_PROP_RW double      C;  // for CV_SVM_C_SVC,       CV_SVM_EPS_SVR and CV_SVM_NU_SVR
CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
CvMat*      class_weights; // for CV_SVM_C_SVC
CV_PROP_RW CvTermCriteria term_crit; // termination criteria
};


所以在例程中我们定义了一个结构体变量用来配置这些参数,而这个变量也就是CVSVM类中train函数的第五个参数,下面对参数进行说明。

SVM_params.svm_type
:SVM的类型:

C_SVC
表示SVM分类器,
C_SVR
表示SVM回归

SVM_params.kernel_type
:核函数类型

线性核
LINEAR
:

d(x,y)=(x,y)

多项式核
POLY
:

d(x,y)=(gamma*(x’y)+coef0)degree

径向基核
RBF
:

d(x,y)=exp(-gamma*|x-y|^2)

sigmoid核
SIGMOID
:

d(x,y)= tanh(gamma*(x’y)+ coef0)

SVM_params.degree:核函数中的参数degree,针对多项式核函数;

SVM_params.gama:核函数中的参数gamma,针对多项式/RBF/SIGMOID核函数;

SVM_params.coef0:核函数中的参数,针对多项式/SIGMOID核函数;

SVM_params.c:SVM最优问题参数,设置
C-SVC
EPS_SVR
NU_SVR
的参数;

SVM_params.nu:SVM最优问题参数,设置
NU_SVC
ONE_CLASS
NU_SVR
的参数;

SVM_params.p:SVM最优问题参数,设置E
PS_SVR
中损失函数p的值.

训练模型

CvSVM svm;
svm.train(trainingData, classes, Mat(), Mat(), SVM_params);


通过上面的过程,我们准备好了待训练的数据和训练需要的参数,其实可以理解为这个准备工作就是在为svm.train()函数准备实参的过程。来看一下svm.train()函数,Opencv将SVM封装成CvSVM库,这个库是基于台湾大学林智仁(Lin Chih-Jen)教授等人开发的LIBSVM封装的,由于篇幅限制,不再全部粘贴库的定义,所以一下代码只是CvSVM库中的一部分数据和函数:

class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
virtual bool train(
const CvMat* trainData,
const CvMat* responses,
const CvMat* varIdx=0,
const CvMat* sampleIdx=0,
CvSVMParams params=CvSVMParams() );
virtual float predict(
const CvMat* sample,
bool returnDFVal=false ) const;


我们就是应用类中定义的train函数完成模型训练工作。

保存模型

svm.save("svm.xml");


保存模型只有一行代码,利用save()函数,我们看下它的定义:

CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;


该函数被定义在CvStatModel类中,CvStatModel是ML库中的统计模型基类,其他 ML 类都是从这个类中继承。

总结:到这里我们就完成了模型训练工作,可以看到真正用于训练的代码其实很少,OpenCV最支持向量机的封装极大地降低了我们的编程工作。

加载模型实现分类

#include <stdio.h>
#include <time.h>
#include <opencv2/opencv.hpp>
#include <opencv/cv.h>
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
#include <io.h>

using namespace std;
using namespace cv;

void getFiles( string path, vector<string>& files );

int main()
{
int result = 0;
char * filePath = "D:\\train\\has\\test";
vector<string> files;
getFiles(filePath, files );
int number = files.size();
cout<<number<<endl;
CvSVM svm;
svm.clear();
string modelpath = "svm.xml";
FileStorage svm_fs(modelpath,FileStorage::READ);
if(svm_fs.isOpened())
{
svm.load(modelpath.c_str());
}
for (int i = 0;i < number;i++)
{
Mat inMat = imread(files[i].c_str());
Mat p = inMat.reshape(1, 1);
p.convertTo(p, CV_32FC1);
int response = (int)svm.predict(p);
if (response == 1)
{
result++;
}
}
cout<<result<<endl;
getchar();
return  0;
}
void getFiles( string path, vector<string>& files )
{
long   hFile   =   0;
struct _finddata_t fileinfo;
string p;
if((hFile = _findfirst(p.assign(path).append("\\*").c_str(),&fileinfo)) !=  -1)
{
do
{
if((fileinfo.attrib &  _A_SUBDIR))
{
if(strcmp(fileinfo.name,".") != 0  &&  strcmp(fileinfo.name,"..") != 0)
getFiles( p.assign(path).append("\\").append(fileinfo.name), files );
}
else
{       files.push_back(p.assign(path).append("\\").append(fileinfo.name) );
}
}while(_findnext(hFile, &fileinfo)  == 0);
_findclose(hFile);
}
}


在上面我们把该介绍的都说的差不多了,这个例程中只是用到了load()函数用于模型加载,加载的就是上面例子中生成的模型,load()被定义在CvStatModel这个基类中,然后用到predict()函数用来预测分类结果,predict()被定义在CVSVM类中。

结果:

测试了517张正样本,1815张负样本,但是只成功预测了216和1481。这个准确率很差,因为我们并没有针对样本做合适的特征提取工作。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: