您的位置:首页 > 编程语言 > C语言/C++

LibSVM C/C++

2015-12-04 21:17 591 查看
本系列文章由 @YhL_Leo 出品,转载请注明出处。

文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779

LibSVM 库的 svm.h 头文件中定义了四个主要结构体:

1 训练模型的结构体

/* Training data handed to svm_train(): l samples, each with a label y[i]
   and a sparse feature vector x[i]. */
struct svm_problem
{
int l;                // total number of samples
double *y;            // label of each sample, y[0..l-1]
struct svm_node **x;  // feature vector of each sample, x[0..l-1]; each x[i] is an
                      // svm_node array whose last node has index == -1
};


样本的类别通常使用 +1 和 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。

2 数据节点的结构体

/* One (index, value) entry of a sparse feature vector. */
struct svm_node
{
int index;    // feature index; strictly increasing within a sample, and -1 marks the end of the sample
double value; // feature value
};


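数据组织结构如图1所示(原图缺失):每个样本对应一个 svm_node 数组,按 index 递增的顺序存放各特征的 (index, value) 对,并以 index = -1 的节点作为结束标志;prob.x[i] 指向第 i 个样本数组的首元素。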



3 模型参数结构体

/* Training / kernel parameters; the command-line option for each field is
   listed in the table below this struct. */
struct svm_parameter
{
int svm_type;    /* one of C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR (-s) */
int kernel_type; /* one of LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED (-t) */
int degree; /* for poly */
double gamma;   /* for poly/rbf/sigmoid (-g, default 1/num_features) */
double coef0;   /* for poly/sigmoid */

/* these are for training only */
double cache_size; /* in MB */
double eps; /* stopping criteria */
double C;   /* for C_SVC, EPSILON_SVR and NU_SVR */
int nr_weight;      /* for C_SVC: number of per-class weights given with -wi */
int *weight_label;  /* for C_SVC: class labels whose C is rescaled (presumably nr_weight entries) */
double* weight;     /* for C_SVC: matching factors, C of that class becomes weight*C */
double nu;  /* for NU_SVC, ONE_CLASS, and NU_SVR */
double p;   /* for EPSILON_SVR */
int shrinking;  /* use the shrinking heuristics */
int probability; /* do probability estimates */
};


其中,各个参数的含义为:

-s svm_type : set type of SVM (default 0)
0 -- C-SVC
1 -- nu-SVC
2 -- one-class SVM
3 -- epsilon-SVR
4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
0 -- linear: u'*v
1 -- polynomial: (gamma*u'*v + coef0)^degree
2 -- radial basis function: exp(-gamma*|u-v|^2)
3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)


SVM模型类型和核函数类型:

/* Values for svm_parameter::svm_type, in the order of the -s option (0..4). */
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
/* Values for svm_parameter::kernel_type, in the order of the -t option.
   With PRECOMPUTED, each sample's first column must be 0:sample_serial_number. */
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */


4 训练输出模型结构体

/* Trained model, produced by svm_train() or svm_load_model().
   In the size annotations below, k stands for nr_class. */
struct svm_model
{
struct svm_parameter param; /* parameter */
int nr_class;       /* number of classes, = 2 in regression/one class svm */
int l;          /* total #SV */
struct svm_node **SV;       /* SVs (SV[l]) */
double **sv_coef;   /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
double *rho;        /* constants in decision functions (rho[k*(k-1)/2]) */
double *probA;      /* pairwise probability information */
double *probB;      /* pairwise probability information (used together with probA) */
int *sv_indices;        /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set */

/* for classification only */

int *label;     /* label of each class (label[k]) */
int *nSV;       /* number of SVs for each class (nSV[k]) */
/* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
/* XXX */
int free_sv;        /* 1 if svm_model is created by svm_load_model*/
/* 0 if svm_model is created by svm_train */
};


5 使用方法

以 LibSVM 提供的样本特征集 heart_scale 为例,首先需要读取样本特征数据。可以利用 svm-train.c 文件中的 read_problem 函数,为了方便使用,对其进行了改写:

// TrainingDataLoad.h
/*
Load training data from svm format file.

- Editor: Yahui Liu.
- Data:   2015-11-30
- Email:  yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/

#ifndef TRAINING_DATA_LOAD_H
#define TRAINING_DATA_LOAD_H
#pragma once

#include <ctype.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "svm.h"
//#include "svm-scale.c"

using namespace std;

#define MAX_LINE_LEN 1024

/*
  Loader for LibSVM-format data.
  - readProblem(): reads a feature file into an svm_problem.
  - initialParams(): fills default training parameters.
  - loadModel(): loads a saved model.
*/
class TrainingDateLoad
{
public:
    TrainingDateLoad() : line(NULL) {}

    // Only clears the pointer; the buffer itself is not released here.
    ~TrainingDateLoad() { line = NULL; }

    // Scratch buffer holding the line most recently read by readline().
    char* line;

    /*! load an svm model from `filename` into `model` via svm_load_model */
    void loadModel( std::string filename, struct svm_model*& model );

    /*! advance `p` past leading whitespace and the label token */
    void svmSkipTarget( char*& p );

    /*! advance `p` past the next "index:value" element */
    void svmSkipElement( char*& p );

    /*! fill `param` with default training parameters */
    void initialParams( struct svm_parameter& param );

    /*! load training data from `filename` into `prob`, possibly adjusting `param` */
    void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );

    /*! read one complete line from `input` into `line`; NULL when nothing could be read */
    char* readline( FILE *input );

    /*! report a malformed input line and terminate the process */
    void exit_input_error( int line_num )
    {
        std::cout << "Wrong input format at line: " << line_num << std::endl;
        exit(1);
    }
};

#endif // TRAINING_DATA_LOAD_H


// TrainingDataLoad.cpp
#include "TrainingDataLoad.h"

/*
  Load a model previously written by svm_save_model.
  - filename: path of the model file.
  - model:    receives whatever svm_load_model returns for that path.
*/
void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
{
    const char* path = filename.c_str();
    struct svm_model* loaded = svm_load_model(path);
    model = loaded;
}

/*
  Advance `p` past any leading whitespace and then past the first
  whitespace-delimited token (the label at the start of a LibSVM line).
  Stops at the terminating '\0' instead of running past the end of the
  string when the label is the last token.
*/
void TrainingDateLoad::svmSkipTarget(char*& p)
{
    // isspace() requires an unsigned-char value; a negative char is UB.
    while(*p && isspace((unsigned char)*p)) ++p;

    while(*p && !isspace((unsigned char)*p)) ++p;
}

/*
  Advance `p` past the next "index:value" element.
  - Moves to the ':'.
  - Skips whitespace after it.
  - Skips the value token.

  If no ':' remains, `p` is left on the terminating '\0' instead of
  scanning past the end of the string.
*/
void TrainingDateLoad::svmSkipElement(char*& p)
{
    while(*p && *p!=':') ++p;
    if(*p == '\0')   // no element left on this line
        return;

    ++p;   // step over ':'
    while(isspace((unsigned char)*p)) ++p;
    while(*p && !isspace((unsigned char)*p)) ++p;
}

/*
  Fill `param` with the default option values:
  - C-SVC with an RBF kernel, degree 3, coef0 0
  - C = 1, nu = 0.5, p = 0.1, eps = 1e-3
  - 100 MB kernel cache
  - shrinking on, no probability estimates, no per-class weights

  gamma is left at 0; readProblem() replaces it with 1/max_feature_index
  when it is still 0 and the data has features.
*/
void TrainingDateLoad::initialParams( struct svm_parameter& param )
{
    // model and kernel
    param.svm_type    = C_SVC;
    param.kernel_type = RBF;
    param.degree      = 3;
    param.gamma       = 0;   // placeholder for 1/num_features
    param.coef0       = 0;

    // solver
    param.C           = 1;
    param.nu          = 0.5;
    param.p           = 0.1;
    param.eps         = 1e-3;
    param.cache_size  = 100;
    param.shrinking   = 1;
    param.probability = 0;

    // no class weights
    param.weight       = NULL;
    param.weight_label = NULL;
    param.nr_weight    = 0;
}

/*
  Read a LibSVM-format file ("<label> <index>:<value> ...", one sample per
  line) into `prob`.

  - filename: path of the feature file.
  - prob:     on return holds prob.l samples. prob.y, prob.x and the node
              pool that every prob.x[i] points into are allocated with new[]
              and owned by the caller. The pool starts at prob.x[0] when
              prob.l > 0.
  - param:    if param.gamma is 0 (and features exist) it is set to
              1/max_feature_index. With a PRECOMPUTED kernel, the leading
              0:sample_serial_number column is validated.

  A file that cannot be opened, or any malformed line, terminates the
  process. The member buffer `line` is allocated with malloc() here, because
  readline() grows it with realloc(). It is released before returning.
*/
void TrainingDateLoad::readProblem( std::string filename,
    struct svm_problem& prob, struct svm_parameter& param )
{
    int max_index, inst_max_index, i;
    size_t elements, j;
    char *endptr;
    char *idx, *val, *label;

    FILE *fp = fopen(filename.c_str(),"r");
    if(fp == NULL)
    {
        // %s needs a C string; passing std::string through varargs is undefined behavior.
        fprintf(stderr,"can't open input file %s\n",filename.c_str());
        exit(1);
    }

    prob.l = 0;
    elements = 0;

    // Must come from malloc(): readline() enlarges it with realloc().
    line = (char *) malloc(MAX_LINE_LEN);
    if(line == NULL)
    {
        fprintf(stderr,"out of memory\n");
        exit(1);
    }

    // Pass 1: count samples and nodes (features plus one terminator per sample).
    while(readline(fp)!=NULL)
    {
        char *p = strtok(line," \t"); // label

        // features
        while(1)
        {
            p = strtok(NULL," \t");
            if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
                break;
            ++elements;
        }
        ++elements;   // the index == -1 terminator of this sample
        ++prob.l;
    }
    rewind(fp);

    prob.y = new double[prob.l];
    prob.x = new struct svm_node *[prob.l];
    struct svm_node *x_space = new struct svm_node[elements];

    // Pass 2: parse every sample into the shared node pool.
    max_index = 0;
    j=0;
    for(i=0;i<prob.l;i++)
    {
        inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
        if(readline(fp) == NULL)   // file got shorter since pass 1
            exit_input_error(i+1);
        prob.x[i] = &x_space[j];
        label = strtok(line," \t\n");
        if(label == NULL) // empty line
            exit_input_error(i+1);

        prob.y[i] = strtod(label,&endptr);
        if(endptr == label || *endptr != '\0')
            exit_input_error(i+1);

        while(1)
        {
            idx = strtok(NULL,":");
            val = strtok(NULL," \t");

            if(val == NULL)
                break;

            errno = 0;
            x_space[j].index = (int) strtol(idx,&endptr,10);
            // indices must be well-formed and strictly increasing within a sample
            if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
                exit_input_error(i+1);
            else
                inst_max_index = x_space[j].index;

            errno = 0;
            x_space[j].value = strtod(val,&endptr);
            if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace((unsigned char)*endptr)))
                exit_input_error(i+1);

            ++j;
        }

        if(inst_max_index > max_index)
            max_index = inst_max_index;
        x_space[j++].index = -1;   // end-of-sample marker
    }

    if(param.gamma == 0 && max_index > 0)
        param.gamma = 1.0/max_index;

    if(param.kernel_type == PRECOMPUTED)
        for(i=0;i<prob.l;i++)
        {
            if (prob.x[i][0].index != 0)
            {
                fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
                exit(1);
            }
            if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
            {
                fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
                exit(1);
            }
        }

    free(line);
    line = NULL;
    fclose(fp);
}

/*
  Read one complete line (including its '\n', if any) from `input` into the
  member buffer `line`. The buffer is grown with realloc(), doubling from
  MAX_LINE_LEN, until a newline or end of file is reached.

  - input: open stream to read from.
  - returns `line`, or NULL when fgets() reads nothing (end of file or error).

  `line` must be NULL or a malloc()/realloc() block; a NULL buffer is
  allocated here. Running out of memory terminates the process, matching
  the loader's other fatal errors.
*/
char* TrainingDateLoad::readline(FILE *input)
{
    if(line == NULL)
    {
        line = (char *) malloc(MAX_LINE_LEN);
        if(line == NULL)
        {
            fprintf(stderr,"out of memory\n");
            exit(1);
        }
    }

    if(fgets(line,MAX_LINE_LEN,input) == NULL)
        return NULL;

    size_t capacity = MAX_LINE_LEN;
    while(strrchr(line,'\n') == NULL)
    {
        capacity *= 2;
        // Keep the old block until realloc succeeds, so a failure cannot lose it.
        char *grown = (char *) realloc(line,capacity);
        if(grown == NULL)
        {
            fprintf(stderr,"out of memory\n");
            exit(1);
        }
        line = grown;

        size_t len = strlen(line);
        if(fgets(line+len,(int)(capacity-len),input) == NULL)
            break;   // last line without a trailing newline
    }
    return line;
}


将样本训练与预测进行改写:

// LibSVMTools.h
/*
LibSVM train and predict tools.

- Editor: Yahui Liu.
- Data:   2015-12-3
- Email:  yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/

#ifndef LIBSVM_TOOL_H
#define LIBSVM_TOOL_H
#pragma once

#include <iostream>
#include <string>

#include "svm.h"
#include "TrainingDataLoad.h"

/*
  Small front-end over LibSVM:
  - train a model from a LibSVM-format feature file and save it;
  - predict the labels of a feature file with a saved model.
*/
class LibSVMTools
{
public:
    LibSVMTools() = default;
    ~LibSVMTools() = default;

    /*!
      Train on `featureFile` and write the model to `saveModelFile`.
      - featureFile:   features of images saved in libsvm format.
      - saveModelFile: save the trained model file.
    **/
    void libSvmTrain(std::string featureFile, std::string saveModelFile);

    /*!
      Predict every sample of `featureFile` with the model in `modelFile`,
      writing one predicted label per line to `savePredictFile`.
      - featureFile:     features of images saved in libsvm format.
      - modelFile:       libsvm trained model.
      - savePredictFile: save the predicting results.
    **/
    void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
};

#endif // LIBSVM_TOOL_H


// LibSVMTools.cpp
#include "LibSVMTools.h"

void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
{
struct svm_parameter param;
struct svm_problem prob;

TrainingDateLoad* trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);

const char*errorMsg = svm_check_parameter(&prob, ¶m);
if ( errorMsg )
{
cout << errorMsg << endl;
return;
}

struct svm_model *model = svm_train(&prob, ¶m);

#if 1
cout << "svm_type: " << model->param.svm_type << endl <<
"kernel_type: " << model->param.kernel_type << endl <<
"gamma: " << model->param.gamma << endl <<
"nr_class: " << model->nr_class << endl <<
"total_sv: " << model->l << endl <<
"rho: " << model->rho[0] << endl <<
"label: " << model->label[0] << " " << model->label[1] << endl <<
"nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;
#endif

int saveModel = svm_save_model( saveModelFile.c_str(), model );
}

void LibSVMTools::libSvmPredict(std::string featureFile,
std::string modelFile, std::string savePredictFile)
{
struct svm_parameter param;
struct svm_problem prob;

TrainingDateLoad * trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);

struct svm_model* model;
trainData->loadModel(modelFile.c_str(), model);

float correct(0.0);     // all correct
float uncorrect_1(0.0); // pos to neg
float uncorrect_2(0.0); // neg to pos
if ( prob.l )
{
const int nCount = prob.l;;

ofstream outfile( savePredictFile, ios::out );
for( int i=0; i<nCount; i++ )
{
double label = svm_predict(model, prob.x[i]);
if ( label == prob.y[i] )
{
correct ++;
}
else if ( label == -1.0 )
{
uncorrect_1 ++;
}
else
{
uncorrect_2 ++;
}
outfile << label << endl;
}
#if 1
cout << "total data count: " << nCount << endl <<
"classification correct: " << correct << endl <<
"pos to neg count: " << uncorrect_1 << endl <<
"neg to pos count: " << uncorrect_2 << endl;

cout << "Accuracy: " << static_cast<float>(correct/nCount)
<< "(" << correct << "/" << nCount << ")" << endl;
#endif
outfile.close();
}
}


用例Demo:

// train
#include "LibSVMTools.h"

/*
  Train demo: trains on the bundled heart_scale sample set and saves the
  model to ..\..\..\Data\train.model.
  main must return int in standard C++.
*/
int main()
{
    std::cout <<
        "************************************************************" << std::endl <<
        "**          PROGRAM: LibSVM model training.               **" << std::endl <<
        "**                                                        **" << std::endl <<
        "**           Author: Yahui Liu.                           **" << std::endl <<
        "**                   School of Remote Sensing & Inf. Eng. **" << std::endl <<
        "**                   Wuhan University, Hubei, P.R. China  **" << std::endl <<
        "**            Email: yahui.cvrs@gmail.com                 **" << std::endl <<
        "**      Create time: Dec. 1, 2015                         **" << std::endl <<
        "************************************************************" << std::endl;

    // The sample set is "heart_scale" (as used by the predict demo), not "heat_scale".
    std::string filename = "..\\..\\..\\Data\\heart_scale";
    std::string saveModelFile = "..\\..\\..\\Data\\train.model";

    LibSVMTools libsvm;
    libsvm.libSvmTrain(filename, saveModelFile);

    return 0;
}

/*------------------------------------------------------------------------------------*/

// predict
#include "LibSVMTools.h"

/*
  Predict demo: classifies heart_scale with the model saved by the train
  demo and writes the predictions to ..\..\..\Data\predict.out.
  main must return int in standard C++.
*/
int main()
{
    std::cout <<
        "************************************************************" << std::endl <<
        "**          PROGRAM: LibSVM predict.                      **" << std::endl <<
        "**                                                        **" << std::endl <<
        "**           Author: Yahui Liu.                           **" << std::endl <<
        "**                   School of Remote Sensing & Inf. Eng. **" << std::endl <<
        "**                   Wuhan University, Hubei, P.R. China  **" << std::endl <<
        "**            Email: yahui.cvrs@gmail.com                 **" << std::endl <<
        "**      Create time: Dec. 1, 2015                         **" << std::endl <<
        "************************************************************" << std::endl;

    std::string featureFile = "..\\..\\..\\Data\\heart_scale";
    std::string modelFile = "..\\..\\..\\Data\\train.model";
    std::string savePredictFile = "..\\..\\..\\Data\\predict.out";

    LibSVMTools libsvm;
    libsvm.libSvmPredict(featureFile, modelFile, savePredictFile);

    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: