LibSVM C/C++
2015-12-04 21:17
This series of articles is by @YhL_Leo. Please credit the source when reposting.
Article link: http://blog.csdn.net/yhl_leo/article/details/50179779
The svm.h header of the LibSVM library defines four main structs:
1 Training data struct
struct svm_problem
{
    int l;                  // total number of samples
    double *y;              // label of each sample
    struct svm_node **x;    // feature vector of each sample
};
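For binary classification, sample classes are usually labeled +1 and -1. If the classes of the samples are unknown, the classification accuracy cannot be computed.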
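The training data does not have to come from a file. The code below is a minimal sketch that assumes the samples are already in memory as a dense matrix, and it fills an svm_problem by hand. The labels go into y. Each row of x becomes a sparse run of svm_node entries that skips zero features, uses 1-based indices and ends with index -1. The two structs are local copies of the svm.h declarations, included only so the snippet compiles on its own; in a real program you would include svm.h instead. makeProblem and SparseStorage are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Local copies mirroring svm.h so the sketch compiles without LibSVM.
struct svm_node { int index; double value; };
struct svm_problem { int l; double *y; struct svm_node **x; };

// Hypothetical owner of the memory that svm_problem points into.
struct SparseStorage {
    std::vector<double> labels;      // backs prob.y
    std::vector<svm_node> nodes;     // all sample rows, stored back to back
    std::vector<svm_node*> rows;     // backs prob.x
};

// Converts a dense matrix (one row per sample) into LibSVM's sparse layout:
// zero features are skipped, indices are 1-based, each row ends with index -1.
svm_problem makeProblem(const std::vector<std::vector<double>>& features,
                        const std::vector<double>& labels,
                        SparseStorage& store)
{
    assert(features.size() == labels.size());
    store.labels = labels;
    store.nodes.clear();
    std::vector<std::size_t> rowStart;  // offsets, because nodes may reallocate while growing
    for (const auto& row : features) {
        rowStart.push_back(store.nodes.size());
        for (std::size_t k = 0; k < row.size(); ++k)
            if (row[k] != 0.0)
                store.nodes.push_back({static_cast<int>(k + 1), row[k]});
        store.nodes.push_back({-1, 0.0});  // end-of-sample marker
    }
    store.rows.clear();
    for (std::size_t off : rowStart)       // take pointers only after the last push_back
        store.rows.push_back(&store.nodes[off]);

    svm_problem prob;
    prob.l = static_cast<int>(store.labels.size());
    prob.y = store.labels.data();
    prob.x = store.rows.data();
    return prob;
}
```

The returned svm_problem only borrows memory, so it stays valid only as long as the SparseStorage object is alive and unchanged.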
2 Data node struct
struct svm_node
{
    int index;      // feature index (starts at 1); index = -1 marks the end of a sample
    double value;   // feature value
};
The data organization is shown in Figure 1.
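In this layout, each sample is a contiguous run of svm_node entries whose index values strictly increase. The run is closed by a node with index -1, and zero-valued features are simply absent. For example, the sample "+1 1:0.5 3:-1" is stored as (1, 0.5), (3, -1), (-1, unused). The sketch below shows why this layout is convenient. Two rows can be multiplied by walking both in step and combining only the matching indices. As far as I know, LibSVM uses the same idea internally for its dot products. svm_node is a local copy of the svm.h struct, and sparseDot is a hypothetical helper.

```cpp
#include <cassert>

// Local copy of the svm.h struct so the sketch compiles without LibSVM.
struct svm_node { int index; double value; };

// Dot product of two rows in LibSVM layout (indices strictly increasing,
// row terminated by index == -1). Missing features count as zero, so only
// indices present in both rows contribute to the sum.
double sparseDot(const svm_node* px, const svm_node* py)
{
    double sum = 0.0;
    while (px->index != -1 && py->index != -1) {
        if (px->index == py->index) {
            sum += px->value * py->value;
            ++px;
            ++py;
        } else if (px->index > py->index) {
            ++py;   // py is behind: advance it
        } else {
            ++px;   // px is behind: advance it
        }
    }
    return sum;
}
```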
3 Model parameter struct
struct svm_parameter
{
    int svm_type;
    int kernel_type;
    int degree;         /* for poly */
    double gamma;       /* for poly/rbf/sigmoid */
    double coef0;       /* for poly/sigmoid */

    /* these are for training only */
    double cache_size;  /* in MB */
    double eps;         /* stopping criteria */
    double C;           /* for C_SVC, EPSILON_SVR and NU_SVR */
    int nr_weight;      /* for C_SVC */
    int *weight_label;  /* for C_SVC */
    double* weight;     /* for C_SVC */
    double nu;          /* for NU_SVC, ONE_CLASS, and NU_SVR */
    double p;           /* for EPSILON_SVR */
    int shrinking;      /* use the shrinking heuristics */
    int probability;    /* do probability estimates */
};
The meaning of each parameter, as documented by the corresponding svm-train command-line option, is:
-s svm_type : set type of SVM (default 0)
    0 -- C-SVC
    1 -- nu-SVC
    2 -- one-class SVM
    3 -- epsilon-SVR
    4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
    0 -- linear: u'*v
    1 -- polynomial: (gamma*u'*v + coef0)^degree
    2 -- radial basis function: exp(-gamma*|u-v|^2)
    3 -- sigmoid: tanh(gamma*u'*v + coef0)
    4 -- precomputed kernel (kernel values in training_set_file)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates : whether to train an SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)
SVM model types and kernel function types:
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };    /* svm_type */
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED };          /* kernel_type */
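The sketch below makes the kernel formulas from the -t option concrete by evaluating them for two dense vectors. The enum is copied from svm.h. KernelParams is a hypothetical subset of svm_parameter that holds kernel_type, degree, gamma and coef0. PRECOMPUTED is rejected because those kernel values are supplied in the data file rather than computed. When gamma is left at 0, it defaults to 1/num_features.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Copied from svm.h: kernel_type values.
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED };

// Hypothetical subset of svm_parameter holding only the kernel fields.
struct KernelParams {
    int kernel_type;
    int degree;      // for POLY
    double gamma;    // for POLY/RBF/SIGMOID
    double coef0;    // for POLY/SIGMOID
};

static double dot(const std::vector<double>& u, const std::vector<double>& v)
{
    double s = 0.0;
    for (std::size_t i = 0; i < u.size(); ++i) s += u[i] * v[i];
    return s;
}

// Evaluates K(u, v) with the formulas listed for the -t option.
double kernelValue(const std::vector<double>& u, const std::vector<double>& v,
                   const KernelParams& p)
{
    assert(u.size() == v.size());
    switch (p.kernel_type) {
    case LINEAR:
        return dot(u, v);                                        // u'*v
    case POLY:
        return std::pow(p.gamma * dot(u, v) + p.coef0, p.degree);
    case RBF: {
        double d2 = 0.0;                                         // |u-v|^2
        for (std::size_t i = 0; i < u.size(); ++i)
            d2 += (u[i] - v[i]) * (u[i] - v[i]);
        return std::exp(-p.gamma * d2);
    }
    case SIGMOID:
        return std::tanh(p.gamma * dot(u, v) + p.coef0);
    default:
        throw std::invalid_argument("precomputed kernel values come from the data file");
    }
}
```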
4 Trained model struct (training output)
struct svm_model
{
    struct svm_parameter param;  /* parameter */
    int nr_class;           /* number of classes, = 2 in regression/one class svm */
    int l;                  /* total #SV */
    struct svm_node **SV;   /* SVs (SV[l]) */
    double **sv_coef;       /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
    double *rho;            /* constants in decision functions (rho[k*(k-1)/2]) */
    double *probA;          /* pairwise probability information */
    double *probB;
    int *sv_indices;        /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set */

    /* for classification only */

    int *label;             /* label of each class (label[k]) */
    int *nSV;               /* number of SVs for each class (nSV[k]) */
                            /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
    /* XXX */
    int free_sv;            /* 1 if svm_model is created by svm_load_model */
                            /* 0 if svm_model is created by svm_train */
};
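For a two-class model, nr_class = 2. That gives k*(k-1)/2 = 1 decision function, so prediction needs only SV, sv_coef[0], rho[0] and label. The sketch below assumes an RBF kernel and dense vectors, and uses a hypothetical BinaryModel struct in place of svm_model. It computes f(x) = sum_i sv_coef[0][i]*K(SV[i], x) - rho[0]. It then returns label[0] when f(x) > 0 and label[1] otherwise, which is my understanding of how LibSVM's one-vs-one voting resolves the two-class case.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical, simplified view of a two-class svm_model with an RBF kernel:
// sv[i] ~ SV[i], coef[i] ~ sv_coef[0][i], rho ~ rho[0], label[] ~ label[].
struct BinaryModel {
    std::vector<std::vector<double>> sv;  // support vectors (dense for brevity)
    std::vector<double> coef;             // one coefficient per support vector
    double rho;                           // constant subtracted from the sum
    double gamma;                         // RBF parameter
    int label[2];                         // class labels in the model's order
};

static double rbf(const std::vector<double>& u, const std::vector<double>& v, double gamma)
{
    double d2 = 0.0;
    for (std::size_t i = 0; i < u.size(); ++i) d2 += (u[i] - v[i]) * (u[i] - v[i]);
    return std::exp(-gamma * d2);
}

// f(x) = sum_i coef_i * K(sv_i, x) - rho
double decisionValue(const BinaryModel& m, const std::vector<double>& x)
{
    double sum = 0.0;
    for (std::size_t i = 0; i < m.sv.size(); ++i)
        sum += m.coef[i] * rbf(m.sv[i], x, m.gamma);
    return sum - m.rho;
}

// A positive decision value votes for label[0]; otherwise label[1] wins.
int predictLabel(const BinaryModel& m, const std::vector<double>& x)
{
    return decisionValue(m, x) > 0 ? m.label[0] : m.label[1];
}
```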
5 Usage
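Take heart_scale, the sample feature set provided with LibSVM, as an example. The sample feature data must be read first. This can be done with the read_problem function in svm-train.c, which I rewrote as follows for convenience: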
// TrainingDataLoad.h
/*
    Load training data from svm format file.
    - Editor: Yahui Liu.
    - Date: 2015-11-30
    - Email: yahui.cvrs@gmail.com
    - Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef TRAINING_DATA_LOAD_H
#define TRAINING_DATA_LOAD_H
#pragma once

#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <errno.h>
#include "svm.h"
//#include "svm-scale.c"

using namespace std;

#define MAX_LINE_LEN 1024

class TrainingDateLoad
{
public:
    TrainingDateLoad() { line = NULL; }
    ~TrainingDateLoad() { line = NULL; }

public:
    char* line;

//public:
//    static struct svm_parameter _paramInit;

public:
    /*! load svm model */
    void loadModel( std::string filename, struct svm_model*& model);
    /*! skip the target */
    void svmSkipTarget( char*& p);
    /* skip the element */
    void svmSkipElement( char*& p);
    void initialParams( struct svm_parameter& param );
    /*! load training data */
    void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );
    char* readline(FILE *input);

    void exit_input_error(int line_num)
    {
        cout << "Wrong input format at line: " << line_num << endl;
        exit(1);
    }
};

#endif // TRAINING_DATA_LOAD_H
// TrainingDataLoad.cpp
#include <string.h>   // strtok, strrchr, strlen
#include "TrainingDataLoad.h"

void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
{
    model = svm_load_model(filename.c_str());   // NULL if the file cannot be read
}

void TrainingDateLoad::svmSkipTarget(char*& p)
{
    while(isspace(*p)) ++p;
    while(*p && !isspace(*p)) ++p;    // stop at the terminating '\0'
}

void TrainingDateLoad::svmSkipElement(char*& p)
{
    while(*p && *p!=':') ++p;
    if(*p) ++p;
    while(isspace(*p)) ++p;
    while(*p && !isspace(*p)) ++p;
}

void TrainingDateLoad::initialParams( struct svm_parameter& param )
{
    // default values
    param.svm_type = C_SVC;
    param.kernel_type = RBF;
    param.degree = 3;
    param.gamma = 0;    // 0 means 1/num_features, filled in by readProblem
    param.coef0 = 0;
    param.nu = 0.5;
    param.cache_size = 100;
    param.C = 1;
    param.eps = 1e-3;
    param.p = 0.1;
    param.shrinking = 1;
    param.probability = 0;
    param.nr_weight = 0;
    param.weight_label = NULL;
    param.weight = NULL;
}

void TrainingDateLoad::readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param )
{
    int max_index, inst_max_index, i;
    size_t elements, j;
    FILE *fp = fopen(filename.c_str(),"r");
    char *endptr;
    char *idx, *val, *label;

    if(fp == NULL)
    {
        fprintf(stderr,"can't open input file %s\n",filename.c_str());
        exit(1);
    }

    prob.l = 0;
    elements = 0;

    // readline() grows this buffer with realloc(), so it must come from malloc(), not new[]
    line = (char *) malloc(MAX_LINE_LEN);

    // first pass: count samples and nodes (features plus one terminator per sample)
    while(readline(fp)!=NULL)
    {
        char *p = strtok(line," \t"); // label

        // features
        while(1)
        {
            p = strtok(NULL," \t");
            if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
                break;
            ++elements;
        }
        ++elements;
        ++prob.l;
    }
    rewind(fp);

    prob.y = new double[prob.l];
    prob.x = new struct svm_node *[prob.l];
    struct svm_node *x_space = new struct svm_node[elements];

    // second pass: parse the label and index:value pairs of every sample
    max_index = 0;
    j=0;
    for(i=0;i<prob.l;i++)
    {
        inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
        readline(fp);
        prob.x[i] = &x_space[j];
        label = strtok(line," \t\n");
        if(label == NULL) // empty line
            exit_input_error(i+1);

        prob.y[i] = strtod(label,&endptr);
        if(endptr == label || *endptr != '\0')
            exit_input_error(i+1);

        while(1)
        {
            idx = strtok(NULL,":");
            val = strtok(NULL," \t");

            if(val == NULL)
                break;

            errno = 0;
            x_space[j].index = (int) strtol(idx,&endptr,10);
            if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
                exit_input_error(i+1);
            else
                inst_max_index = x_space[j].index;

            errno = 0;
            x_space[j].value = strtod(val,&endptr);
            if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
                exit_input_error(i+1);

            ++j;
        }

        if(inst_max_index > max_index)
            max_index = inst_max_index;
        x_space[j++].index = -1;    // end-of-sample marker
    }

    if(param.gamma == 0 && max_index > 0)
        param.gamma = 1.0/max_index;

    if(param.kernel_type == PRECOMPUTED)
        for(i=0;i<prob.l;i++)
        {
            if (prob.x[i][0].index != 0)
            {
                fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
                exit(1);
            }
            if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
            {
                fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
                exit(1);
            }
        }

    fclose(fp);
    free(line);     // release the line buffer allocated above
    line = NULL;
}

char* TrainingDateLoad::readline(FILE *input)
{
    int len;

    if(fgets(line,MAX_LINE_LEN,input) == NULL)
        return NULL;

    // keep doubling the buffer until the whole line (including '\n') has been read
    int max_line_len = MAX_LINE_LEN;
    while(strrchr(line,'\n') == NULL)
    {
        max_line_len *= 2;
        line = (char *) realloc(line,max_line_len);
        len = (int) strlen(line);
        if(fgets(line+len,max_line_len-len,input) == NULL)
            break;
    }
    return line;
}
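readProblem expects each line in the LibSVM text format "label index1:value1 index2:value2 ...", with strictly increasing indices. The strtok-based code above can be hard to follow, so here is a standalone sketch that parses a single line with std::istringstream and produces the same -1-terminated row. parseLine is a hypothetical helper, and svm_node is a local copy of the svm.h struct. Unlike readProblem, the sketch reports errors by returning false instead of calling exit().

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Local copy of the svm.h struct so the sketch compiles without LibSVM.
struct svm_node { int index; double value; };

// Hypothetical helper: parses "label idx:val idx:val ..." into a label and a
// row of nodes terminated by index = -1. Returns false on malformed input
// (missing ':', trailing garbage in a number, or non-increasing indices).
bool parseLine(const std::string& line, double& label, std::vector<svm_node>& row)
{
    std::istringstream in(line);
    row.clear();
    if (!(in >> label))                 // first token: the class label, e.g. +1 or -1
        return false;

    std::string tok;
    int lastIndex = -1;                 // like inst_max_index: allows index 0 for precomputed kernels
    while (in >> tok) {                 // remaining tokens: index:value pairs
        std::size_t colon = tok.find(':');
        if (colon == std::string::npos)
            return false;
        std::string idxText = tok.substr(0, colon);
        std::string valText = tok.substr(colon + 1);
        try {
            std::size_t usedIdx = 0, usedVal = 0;
            int idx = std::stoi(idxText, &usedIdx);
            double val = std::stod(valText, &usedVal);
            if (usedIdx != idxText.size() || usedVal != valText.size() || idx <= lastIndex)
                return false;
            row.push_back({idx, val});
            lastIndex = idx;
        } catch (const std::exception&) {   // std::stoi/std::stod throw on bad or out-of-range text
            return false;
        }
    }
    row.push_back({-1, 0.0});           // end-of-sample marker
    return true;
}
```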
Sample training and prediction are rewritten as follows:
// LibSVMTools.h
/*
    LibSVM train and predict tools.
    - Editor: Yahui Liu.
    - Date: 2015-12-3
    - Email: yahui.cvrs@gmail.com
    - Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef LIBSVM_TOOL_H
#define LIBSVM_TOOL_H
#pragma once

#include <iostream>
#include <string>
#include "svm.h"
#include "TrainingDataLoad.h"

class LibSVMTools
{
public:
    LibSVMTools(){}
    ~LibSVMTools(){}

public:
    /*!
        - featureFile: features of images saved in libsvm format.
        - saveModelFile: save the trained model file.
    **/
    void libSvmTrain(std::string featureFile, std::string saveModelFile);

    /*!
        - featureFile: features of images saved in libsvm format.
        - modelFile: libsvm trained model.
        - savePredictFile: save the predicting results.
    **/
    void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
};

#endif // LIBSVM_TOOL_H
// LibSVMTools.cpp
#include "LibSVMTools.h"

void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
{
    struct svm_parameter param;
    struct svm_problem prob;
    TrainingDateLoad trainData;     // stack object: no leak

    trainData.initialParams( param );
    trainData.readProblem(featureFile, prob, param);

    const char* errorMsg = svm_check_parameter(&prob, &param);
    if ( errorMsg )
    {
        cout << errorMsg << endl;
        return;
    }

    struct svm_model *model = svm_train(&prob, &param);

#if 1
    // label[1] and nSV[1] assume a two-class problem such as heart_scale
    cout << "svm_type: " << model->param.svm_type << endl
         << "kernel_type: " << model->param.kernel_type << endl
         << "gamma: " << model->param.gamma << endl
         << "nr_class: " << model->nr_class << endl
         << "total_sv: " << model->l << endl
         << "rho: " << model->rho[0] << endl
         << "label: " << model->label[0] << " " << model->label[1] << endl
         << "nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;
#endif

    if ( svm_save_model( saveModelFile.c_str(), model ) != 0 )   // 0 on success
        cout << "can't save model to file " << saveModelFile << endl;

    svm_free_and_destroy_model(&model);
    svm_destroy_param(&param);
}

void LibSVMTools::libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile)
{
    struct svm_parameter param;
    struct svm_problem prob;
    TrainingDateLoad trainData;

    trainData.initialParams( param );
    trainData.readProblem(featureFile, prob, param);

    struct svm_model* model = NULL;
    trainData.loadModel(modelFile, model);
    if ( model == NULL )
    {
        cout << "can't open model file " << modelFile << endl;
        return;
    }

    float correct(0.0);         // all correct
    float uncorrect_1(0.0);     // pos to neg
    float uncorrect_2(0.0);     // neg to pos

    if ( prob.l )
    {
        const int nCount = prob.l;
        ofstream outfile( savePredictFile.c_str(), ios::out );
        for( int i=0; i<nCount; i++ )
        {
            double label = svm_predict(model, prob.x[i]);
            if ( label == prob.y[i] )
                correct ++;
            else if ( label == -1.0 )
                uncorrect_1 ++;
            else
                uncorrect_2 ++;

            outfile << label << endl;
        }

#if 1
        cout << "total data count: " << nCount << endl
             << "classification correct: " << correct << endl
             << "pos to neg count: " << uncorrect_1 << endl
             << "neg to pos count: " << uncorrect_2 << endl;
        cout << "Accuracy: " << static_cast<float>(correct/nCount)
             << "(" << correct << "/" << nCount << ")" << endl;
#endif
        outfile.close();
    }

    svm_free_and_destroy_model(&model);
}
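The counters in libSvmPredict assume the labels are +1 and -1. Any wrong prediction of -1 is counted as "pos to neg", and every other mistake as "neg to pos". This bookkeeping can be pulled out into a small pure function that is easy to check in isolation. PredictionStats and summarize are hypothetical names, and the accuracy is correct / total, the same ratio printed above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical summary of binary (+1/-1) predictions, mirroring libSvmPredict's counters.
struct PredictionStats {
    int correct = 0;
    int posToNeg = 0;   // true +1 predicted as -1
    int negToPos = 0;   // true -1 predicted as +1

    double accuracy() const
    {
        int total = correct + posToNeg + negToPos;
        return total ? static_cast<double>(correct) / total : 0.0;
    }
};

PredictionStats summarize(const std::vector<double>& truth, const std::vector<double>& predicted)
{
    assert(truth.size() == predicted.size());
    PredictionStats s;
    for (std::size_t i = 0; i < truth.size(); ++i) {
        if (predicted[i] == truth[i])
            ++s.correct;
        else if (predicted[i] == -1.0)
            ++s.posToNeg;
        else
            ++s.negToPos;
    }
    return s;
}
```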
Demo usage:
// train
#include "LibSVMTools.h"

int main()
{
    std::cout << "************************************************************" << endl
              << "** PROGRAM: LibSVM model training. **" << endl
              << "** **" << endl
              << "** Author: Yahui Liu. **" << endl
              << "** School of Remote Sensing & Inf. Eng. **" << endl
              << "** Wuhan University, Hubei, P.R. China **" << endl
              << "** Email: yahui.cvrs@gmail.com **" << endl
              << "** Create time: Dec. 1, 2015 **" << endl
              << "************************************************************" << endl;

    std::string filename = "..\\..\\..\\Data\\heart_scale";
    std::string saveModelFile = "..\\..\\..\\Data\\train.model";

    LibSVMTools libsvm;
    libsvm.libSvmTrain(filename, saveModelFile);
    return 0;
}

/*------------------------------------------------------------------------------------*/

// predict
#include "LibSVMTools.h"

int main()
{
    std::cout << "************************************************************" << endl
              << "** PROGRAM: LibSVM predict. **" << endl
              << "** **" << endl
              << "** Author: Yahui Liu. **" << endl
              << "** School of Remote Sensing & Inf. Eng. **" << endl
              << "** Wuhan University, Hubei, P.R. China **" << endl
              << "** Email: yahui.cvrs@gmail.com **" << endl
              << "** Create time: Dec. 1, 2015 **" << endl
              << "************************************************************" << endl;

    std::string featureFile = "..\\..\\..\\Data\\heart_scale";
    std::string modelFile = "..\\..\\..\\Data\\train.model";
    std::string savePredictFile = "..\\..\\..\\Data\\predict.out";

    LibSVMTools libsvm;
    libsvm.libSvmPredict(featureFile, modelFile, savePredictFile);
    return 0;
}