您的位置:首页 > 编程语言 > C语言/C++

[code segments] OpenCV3.0 SVM with C++ interface

2017-06-18 14:27 363 查看
Talk is cheap; here is the code:

/************************************************************************/
/* Name   : OpenCV SVM test                                             */
/* Date   : 2015/11/7                                                   */
/* Author : aban                                                        */
/************************************************************************/
// Note: this code is adapted from an example found on the internet.

#include <iostream>
#include <cmath>
#include <string>
using namespace std;

#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
using namespace cv;

// ---- Demo configuration -------------------------------------------------
// Whether svm() opens an extra window showing the support vectors.
bool plotSupportVectors{true};
// Number of random 2-D samples generated for training.
int numTrainingPoints{200};
// Number of random 2-D samples generated for testing.
int numTestPoints{2000};
// Side length, in pixels, of the square canvases used for plotting.
// NOTE(review): with `using namespace std;` in effect, an unqualified use of
// `size` can clash with std::size when compiling as C++17 or later; refer to
// it as `::size` in that case.
int size{200};
// Selects which curve f() uses to label the samples.
int eq{0};

// Accuracy: fraction of rows for which the predicted and the actual label
// have the same sign.
//
// predicted, actual: single-channel float matrices read through at<float>(i, 0);
//                    both must have the same number of rows (asserted).
// Returns a value in [0, 1]. A label of exactly 0 is treated as matching
// either sign. When there are no rows the result is 0 rather than the NaN
// that a 0/0 division would otherwise produce.
float evaluate(cv::Mat& predicted, cv::Mat& actual) {
    assert(predicted.rows == actual.rows);

    const int total = actual.rows;
    if (total <= 0) {
        return 0.0f;  // nothing to score; avoid dividing by zero
    }

    int correct = 0;
    for (int i = 0; i < total; i++) {
        const float p = predicted.at<float>(i, 0);
        const float a = actual.at<float>(i, 0);
        const bool bothNonNegative = (p >= 0.0 && a >= 0.0);
        const bool bothNonPositive = (p <= 0.0 && a <= 0.0);
        if (bothNonNegative || bothNonPositive) {
            correct++;
        }
    }
    // Divide in double, as before, then narrow to the float return type.
    return static_cast<float>(static_cast<double>(correct) / total);
}

// Plots 2-D samples, coloured by class, in a window titled `name`.
//
// data:    one sample per row; columns 0 and 1 are read as float x and y
//          and scaled by the global canvas side length `::size`.
// classes: one float label per row; label > 0 is drawn red, anything else green.
// name:    window title passed to namedWindow/imshow.
void plot_binary(cv::Mat& data, cv::Mat& classes, string name) {
    // Qualify the global explicitly: with `using namespace std;` an
    // unqualified `size` is ambiguous with std::size in C++17 and later.
    const int side = ::size;

    cv::Mat plot(side, side, CV_8UC3);
    plot.setTo(cv::Scalar(255.0, 255.0, 255.0));  // white background

    for (int i = 0; i < data.rows; i++) {
        const float x = data.at<float>(i, 0) * side;
        const float y = data.at<float>(i, 1) * side;

        // Truncate to pixel coordinates explicitly (previously an implicit
        // float -> int narrowing inside the Point constructor).
        const cv::Point centre(static_cast<int>(x), static_cast<int>(y));

        if (classes.at<float>(i, 0) > 0) {
            cv::circle(plot, centre, 2, CV_RGB(255, 0, 0), 1);
        }
        else {
            cv::circle(plot, centre, 2, CV_RGB(0, 255, 0), 1);
        }
    }

    // cv::WINDOW_KEEPRATIO is the C++ API spelling of the legacy
    // CV_WINDOW_KEEPRATIO constant from the C headers.
    cv::namedWindow(name, cv::WINDOW_KEEPRATIO);
    cv::imshow(name, plot);
}

// Function to learn: classifies the point (x, y) against a boundary curve.
//
// equation selects the curve:
//   0               -> y = sin(10 * x)
//   2               -> y = 2 * x
//   3               -> y = tan(10 * x)
//   1 and all other -> y = cos(10 * x)
//
// Returns -1 when y is strictly greater than the curve at x, otherwise 1.
int f(float x, float y, int equation) {
    bool above;
    if (equation == 0) {
        above = y > sin(x * 10);
    }
    else if (equation == 2) {
        above = y > 2 * x;
    }
    else if (equation == 3) {
        above = y > tan(x * 10);
    }
    else {
        // equation == 1 and every unrecognised value share the cosine curve
        above = y > cos(x * 10);
    }
    return above ? -1 : 1;
}

// Labels every row of `points` with f(): columns 0 and 1 of a row are read
// as float x and y, and the resulting -1 / 1 is stored as a float.
// Returns a points.rows x 1 CV_32FC1 matrix of labels.
cv::Mat labelData(cv::Mat points, int equation) {
    const int count = points.rows;
    cv::Mat labels(count, 1, CV_32FC1);
    for (int row = 0; row < count; ++row) {
        labels.at<float>(row, 0) = static_cast<float>(
            f(points.at<float>(row, 0), points.at<float>(row, 1), equation));
    }
    return labels;
}

// Trains an RBF-kernel C-SVC on the training set, predicts the test set,
// prints the accuracy and shows the predictions (and, when the global
// plotSupportVectors is set, the support vectors) in windows.
//
// trainingData:    one sample per row, passed to train() as ROW_SAMPLE.
// trainingClasses: one float label per row; converted to CV_32S for training.
// testData:        samples to predict.
// testClasses:     expected float labels for testData, used for the accuracy.
//
// If training fails, an error is written to stderr and nothing is plotted.
void svm(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) {

    // The model is trained on integer labels. convertTo rounds to the
    // nearest integer, which is exact for the -1 / 1 labels used here.
    cv::Mat trainingLabels;
    trainingClasses.convertTo(trainingLabels, CV_32S);

    cv::Ptr<cv::ml::SVM> model = cv::ml::SVM::create();
    model->setType(cv::ml::SVM::C_SVC);
    model->setKernel(cv::ml::SVM::RBF);
    //model->setDegree(0);  // for poly
    model->setGamma(20);  // for poly/rbf/sigmoid
    //model->setCoef0(0);   // for poly/sigmoid
    model->setC(7);       // for C_SVC, EPS_SVR and NU_SVR
    //model->setNu(0);      // for NU_SVC, ONE_CLASS, and NU_SVR
    //model->setP(0);       // for EPS_SVR

    model->setTermCriteria(cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, 1000, 1E-6));

    // train() reports failure through its return value; do not go on to
    // predict with an untrained model.
    if (!model->train(trainingData, cv::ml::ROW_SAMPLE, trainingLabels)) {
        std::cerr << "SVM training failed" << std::endl;
        return;
    }

    cv::Mat predicted(testClasses.rows, 1, CV_32F);
    model->predict(testData, predicted);

    cout << "Accuracy_{SVM} = " << evaluate(predicted, testClasses) << endl;
    plot_binary(testData, predicted, "Predictions SVM");

    // plot support vectors
    if (plotSupportVectors) {
        // Qualify the global explicitly: with `using namespace std;` an
        // unqualified `size` is ambiguous with std::size in C++17 and later.
        const int side = ::size;

        cv::Mat plot_sv(side, side, CV_8UC3);
        plot_sv.setTo(cv::Scalar(255.0, 255.0, 255.0));

        const cv::Mat supportVectors = model->getSupportVectors();
        for (int vecNum = 0; vecNum < supportVectors.rows; vecNum++) {
            // Each row holds one support vector; columns 0 and 1 are x and y.
            const float sx = supportVectors.at<float>(vecNum, 0) * side;
            const float sy = supportVectors.at<float>(vecNum, 1) * side;
            cv::circle(plot_sv, cv::Point(static_cast<int>(sx), static_cast<int>(sy)), 3, CV_RGB(0, 0, 0));
        }

        // cv::WINDOW_KEEPRATIO is the C++ API spelling of the legacy
        // CV_WINDOW_KEEPRATIO constant from the C headers.
        cv::namedWindow("Support Vectors", cv::WINDOW_KEEPRATIO);
        cv::imshow("Support Vectors", plot_sv);
    }
}

// Entry point: generates random training and test points, labels them with
// the curve selected by the global `eq`, shows both sets, runs the SVM demo
// and then waits for a key press.
int main() {
    // Fill the training set before the test set; the order of the two
    // randu calls is kept as in the original.
    cv::Mat trainingData(numTrainingPoints, 2, CV_32FC1);
    cv::Mat testData(numTestPoints, 2, CV_32FC1);
    cv::randu(trainingData, 0, 1);
    cv::randu(testData, 0, 1);

    // Ground-truth labels for both sets.
    cv::Mat trainingClasses = labelData(trainingData, eq);
    plot_binary(trainingData, trainingClasses, "Training Data");

    cv::Mat testClasses = labelData(testData, eq);
    plot_binary(testData, testClasses, "Test Data");

    svm(trainingData, trainingClasses, testData, testClasses);

    // Wait for a key press so the windows stay visible.
    cv::waitKey(0);
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: