基于Opencv库中SVM模块的MNIST手写字识别数据库识别
2015-10-24 19:15
597 查看
基于Opencv库中SVM模块的MNIST手写字识别数据库识别代码。
MNIST手写数字数据库有60000个训练样本、10000个测试样本。它是更大的数据集NIST的一个子集。数字已经过尺寸归一化(size-normalized),是固定大小的图像。
官方地址:http://yann.lecun.com/exdb/mnist/
该网站对这个数据集有详细介绍。这里提一下,数据是以二进制文件格式存储的,不是图片格式,所以需要注意其数据存放格式,并在OpenCV中进行数据格式转换。
数据格式:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
SVM的识别错误率:
界面:
环境:opencv2.4+Ubuntu+linux
其中一个数据(已经被博主归一化了大小):
导入数据的输出提示:
模型训练提示输出:
测试集测试结果:线性核下正确率92.83%,低于上面网站上给出的正确率,可能是参数没有设置好。
MNIST手写数字数据库有60000个训练样本、10000个测试样本。它是更大的数据集NIST的一个子集。数字已经过尺寸归一化(size-normalized),是固定大小的图像。
官方地址:http://yann.lecun.com/exdb/mnist/
该网站对这个数据集有详细介绍。这里提一下,数据是以二进制文件格式存储的,不是图片格式,所以需要注意其数据存放格式,并在OpenCV中进行数据格式转换。
数据格式:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
SVM的识别错误率:
界面:
环境:opencv2.4+Ubuntu+linux
nistlabledata.h
#ifndef NISTLABLEDATA_H #define NISTLABLEDATA_H #include <opencv2/opencv.hpp> #include "nisttraindata.h" #include "trainsformdata.h" using namespace std; using namespace cv; class NISTLableData:public trainsformdata { public: NISTLableData(); ~NISTLableData(); private: long int magic_number; long int number_of_items; static const long int magic_number_setted= 0x801; //friend long int NISTTrainData::trainsform_32bitDataform(long int &data,unsigned char* char_nums); public: unsigned char magic_numbers[4],number_items[4]; long int getnumber_of_items(); bool check_magic_number(); unsigned char lable; void trainsform_Dataforms() { trainsform_32bitDataform(magic_number,magic_numbers); trainsform_32bitDataform(number_of_items,number_items); } void show_Data() { cout<<"magic_number:"<<magic_number<<endl; cout<<"number_of_items:"<<number_of_items<<endl; } }; #endif // NISTLABLEDATA_H
nistlabledata.cpp
#include "nistlabledata.h"

// Start with every decoded header field, every raw header byte and the
// current label zeroed, so the object is in a defined state before any file
// data is read into it (the original left `lable` and the byte arrays
// uninitialized).
NISTLableData::NISTLableData()
    : magic_number(0),
      number_of_items(0),
      magic_numbers(),
      number_items(),
      lable(0)
{
}

NISTLableData::~NISTLableData()
{
}
nisttraindata.h
#ifndef NISTTRAINDATA_H #define NISTTRAINDATA_H #include <opencv2/opencv.hpp> #include "trainsformdata.h" using namespace std; using namespace cv; class NISTTrainData:public trainsformdata { public: NISTTrainData(); ~NISTTrainData(); private: long int magic_number; long int number_of_images; long int number_of_rows; long int number_of_columns; static const long int magic_number_setted= 0x803; public: static const int image_row= 20; static const int image_col= 20; unsigned char magicNum[4], ccount[4], crows[4], ccols[4]; void GetROI(Mat& src, Mat& dst); friend long int trainsform_32bitDataform(long int &data,unsigned char* char_nums); long int getnumber_of_images(); long int getrows(); long int getcols(); void trainsform_Dataforms(); void show_Data(); bool check_magic_number(); uchar data[64]; }; #endif // NISTTRAINDATA_H
nisttraindata.cpp
#include "nisttraindata.h"
#include "trainsformdata.h"

NISTTrainData::NISTTrainData()
{
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;
}

NISTTrainData::~NISTTrainData()
{
}

// Crops the square spanning the longer side of the bounding box of all
// non-zero pixels of `src` (read as uchar), centred on that box, and resizes
// it to Size(image_row, image_col).
//  - Square pixels that fall outside `src` are filled with 0 (the original
//    read past the image buffer in that case).
//  - A blank image yields an all-zero dst (the original computed a negative
//    size and failed in dst.create).
void NISTTrainData::GetROI(Mat& src, Mat& dst)
{
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;

    // Bounding box of the foreground pixels.
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    if (right < left || bottom < top)
    {
        // No foreground at all: nothing to crop.
        dst = Mat::zeros(image_col, image_row, CV_8UC1);
        return;
    }

    Point center;
    center.x = (left + right) / 2;
    center.y = (top + bottom) / 2;
    int width = right - left + 1;
    int height = bottom - top + 1;
    int len = (width < height) ? height : width;

    // Widen the shorter side to `len`, centred on the box (same arithmetic
    // and truncation as before).
    if (width < height)
    {
        left = static_cast<int>(center.x - height*0.5);
    }
    else if (width > height)
    {
        top = static_cast<int>(center.y - width*0.5);
    }

    dst.create(len, len, CV_8UC1);
    for (int i = 0; i < len; i++)
    {
        uchar* out = dst.ptr<uchar>(i);
        const int y = i + top;
        for (int j = 0; j < len; j++)
        {
            const int x = j + left;
            const bool inside = (y >= 0 && y < src.rows && x >= 0 && x < src.cols);
            out[j] = inside ? src.at<uchar>(y, x) : 0;
        }
    }
    resize(dst, dst, Size(image_row, image_col));
}

long int NISTTrainData::getnumber_of_images()
{
    return number_of_images;
}

long int NISTTrainData::getrows()
{
    return number_of_rows;
}

long int NISTTrainData::getcols()
{
    return number_of_columns;
}

// Decodes the four big-endian header fields. Each target is zeroed first so
// repeated calls give the same result (the base-class decoder may add into
// its output argument).
void NISTTrainData::trainsform_Dataforms()
{
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;
    trainsform_32bitDataform(magic_number, magicNum);
    trainsform_32bitDataform(number_of_images, ccount);
    trainsform_32bitDataform(number_of_rows, crows);
    trainsform_32bitDataform(number_of_columns, ccols);
}

void NISTTrainData::show_Data()
{
    cout<<" magic_number: "<<magic_number<<
          " number_of_images: "<<number_of_images<<
          " number_of_rows: "<<number_of_rows<<
          " number_of_columns: "<<number_of_columns<<endl;
}

bool NISTTrainData::check_magic_number()
{
    return (magic_number == magic_number_setted);
}
trainsformdata.h
#ifndef TRAINSFORMDATA_H
#define TRAINSFORMDATA_H

/// Base class of the MNIST IDX header readers; provides decoding of 4-byte
/// big-endian (MSB first) header fields.
class trainsformdata
{
public:
    trainsformdata();
    ~trainsformdata();

    /// Decodes a 4-byte big-endian integer.
    /// @param data      receives the decoded value; it is overwritten (the
    ///                  original added to it, which was only correct when
    ///                  `data` started at 0)
    /// @param char_nums at least 4 bytes, most significant first
    /// @return the decoded value (equal to the new `data`)
    ///
    /// Static because it uses no object state; existing calls through an
    /// object or from derived-class members still compile unchanged.
    static long int trainsform_32bitDataform(long int &data, unsigned char* char_nums)
    {
        unsigned long int value = 0;
        for (int k = 0; k < 4; ++k)
        {
            value = (value << 8) | char_nums[k];
        }
        data = static_cast<long int>(value);
        return data;
    }
};

#endif // TRAINSFORMDATA_H
trainsformdata.cpp
#include "trainsformdata.h"

// The base class carries no data members, so neither construction nor
// destruction has any work to do. They are defined out of line only because
// the header declares them.
trainsformdata::trainsformdata()
{
    // intentionally empty
}

trainsformdata::~trainsformdata()
{
    // intentionally empty
}
mainwindow.cpp
#include "mainwindow.h"
#include "ui_mainwindow.h"
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/opencv.hpp>
#include "qdebug.h"
#include "nisttraindata.h"
#include "nistlabledata.h"
#include <cfloat>
#include <cstring>
#include <fstream>
#include <vector>

using namespace std;
using namespace cv;

/// One preprocessed MNIST sample.
/// - `lable`: the label stored as an ASCII digit ('0'..'9').
/// - `data`: the image_row x image_col ROI pixels as floats.
struct InputData
{
    unsigned char lable;
    float data[NISTTrainData::image_row*NISTTrainData::image_col];
} InputData_;  // NOTE(review): no longer used in this file; kept in case other code refers to it

/// Training samples filled by on_pushButton_2_clicked() and consumed by on_train_clicked().
vector<InputData> buffer;

namespace
{
// Number of features per sample (one per ROI pixel).
const int kFeatureLen = NISTTrainData::image_row*NISTTrainData::image_col;

// Reads one 4-byte header field. Every IDX header field is exactly 4 bytes;
// the original read sizeof(long int) bytes, which is 8 on 64-bit Linux and
// overran the 4-byte buffers while desynchronizing the stream.
bool read_header_field(ifstream& in, unsigned char (&field)[4])
{
    in.read(reinterpret_cast<char*>(field), sizeof(field));
    return !in.fail();
}

// Loads up to `total` image/label pairs from an IDX3 image file and an IDX1
// label file, crops each digit with NISTTrainData::GetROI, and appends the
// samples to `out`.
// Prints a message and returns false if:
//  - either file cannot be opened,
//  - a header is truncated,
//  - the image magic number is wrong,
//  - the image size is non-positive.
// A truncated body stops loading early and keeps the samples read so far.
bool load_mnist_samples(const char* fileName, const char* labelFileName,
                        long total, vector<InputData>& out)
{
    NISTTrainData TData;
    NISTLableData LData;
    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);
    if (ifs.fail())
    {
        cout<<"train fail"<<endl;
        return false;
    }
    if (lab_ifs.fail())
    {
        cout<<"labelFile fail"<<endl;
        return false;
    }

    if (!read_header_field(ifs, TData.magicNum) || !read_header_field(ifs, TData.ccount) ||
        !read_header_field(ifs, TData.crows) || !read_header_field(ifs, TData.ccols))
    {
        cout<<"train header fail"<<endl;
        return false;
    }
    TData.trainsform_Dataforms();
    TData.show_Data();
    if (!TData.check_magic_number())
    {
        cout<<"train magic number fail"<<endl;
        return false;
    }

    if (!read_header_field(lab_ifs, LData.magic_numbers) ||
        !read_header_field(lab_ifs, LData.number_items))
    {
        cout<<"labelFile header fail"<<endl;
        return false;
    }
    LData.trainsform_Dataforms();
    LData.show_Data();

    const long rows = TData.getrows();
    const long cols = TData.getcols();
    if (rows <= 0 || cols <= 0)
    {
        cout<<"train size fail"<<endl;
        return false;
    }

    // Size the raw image buffer from the header instead of assuming 28x28.
    Mat src = Mat::zeros(static_cast<int>(rows), static_cast<int>(cols), CV_8UC1);
    Mat roi;
    InputData sample;
    const long available = TData.getnumber_of_images();
    for (long count = 0; count < total && count < available; ++count)
    {
        ifs.read(reinterpret_cast<char*>(src.data), rows*cols);
        lab_ifs.read(reinterpret_cast<char*>(&LData.lable), sizeof(LData.lable));
        if (ifs.fail() || lab_ifs.fail())
        {
            break;  // truncated file: keep what was read so far
        }

        TData.GetROI(src, roi);
        // Labels are kept as ASCII digits; predict() results are compared against them.
        LData.lable = LData.lable + '0';
        cout<<"lable:"<<LData.lable<<endl;

        sample.lable = LData.lable;
        for (int i = 0; i < NISTTrainData::image_row; i++)
        {
            for (int j = 0; j < NISTTrainData::image_col; j++)
            {
                sample.data[i*NISTTrainData::image_col + j] = roi.at<uchar>(i, j);
            }
        }
        out.push_back(sample);
    }
    return true;
}

// Copies a sample's pixels into `feature` as a 1 x kFeatureLen CV_32FC1 row
// and normalizes it with cv::normalize's defaults. The same routine is used
// for training and prediction so both see identical features.
void sample_to_feature(const InputData& sample, Mat& feature)
{
    feature.create(1, kFeatureLen, CV_32FC1);
    memcpy(feature.data, sample.data, kFeatureLen*sizeof(float));
    normalize(feature, feature);
}
} // namespace

void MainWindow::on_pushButton_2_clicked() // load training data
{
    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";
    const long total = 2000;

    // Load into a scratch vector so a failed or repeated load never leaves
    // `buffer` half-filled or with duplicated samples.
    vector<InputData> loaded;
    if (!load_mnist_samples(fileName, labelFileName, total, loaded))
    {
        return;
    }
    buffer.swap(loaded);
    cout<<"load trainingdata ok"<<endl;
}

void MainWindow::on_train_clicked()
{
    const vector<InputData>& trainData = buffer;
    const int testCount = static_cast<int>(trainData.size());
    if (testCount == 0)
    {
        // The original would fail inside svm.train / trainData.at(0) here.
        cout<<"no training data, load it first"<<endl;
        return;
    }

    // One normalized feature row and one integer class label per sample.
    Mat m;
    Mat data = Mat::zeros(testCount, kFeatureLen, CV_32FC1);
    Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    for (int i = 0; i < testCount; i++)
    {
        const InputData& td = trainData[i];
        sample_to_feature(td, m);
        memcpy(data.ptr<float>(i), m.data, kFeatureLen*sizeof(float));
        res.at<int>(i, 0) = td.lable;  // CV_32SC1 -> int accessor
    }

    CvSVM svm;
    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    //CvSVMParams param(CvSVM::C_SVC, CvSVM::LINEAR, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    cout<<"training..."<<endl<<"it takes a long time, please wait!"<<endl;
    svm.train(data, res, Mat(), Mat(), param);
    cout<<"training finished..."<<endl;
    cout<<"saving \"SVM_DATA.xml\"..."<<endl;
    svm.save("SVM_DATA.xml");
    cout<<"saved..."<<endl;

    // Sanity check: reload the saved model and classify the first training sample.
    CvSVM svmpredict;
    svmpredict.load("SVM_DATA.xml");
    const InputData& td = trainData[0];
    sample_to_feature(td, m);
    char ret = static_cast<char>(svmpredict.predict(m));
    cout<<"ret is :"<<ret<<endl;
    cout<<"labble is :"<<td.lable<<endl;
}

void MainWindow::on_testPredict_clicked()
{
    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    const long total = 10000;

    vector<InputData> Testbuffer;
    if (!load_mnist_samples(fileName, labelFileName, total, Testbuffer))
    {
        return;
    }
    cout<<"load trainingdata ok"<<endl;

    const int testCount = static_cast<int>(Testbuffer.size());
    if (testCount == 0)
    {
        cout<<"no test data"<<endl;
        return;
    }

    CvSVM svmpredict1;
    svmpredict1.load("SVM_DATA.xml");
    cout<<"testing..."<<endl;

    // Print progress roughly every 1%. The step is at least 1 so fewer than
    // 100 samples no longer divides by zero.
    const int step = (testCount / 100 > 0) ? testCount / 100 : 1;
    int count_test = 0;
    Mat m;
    for (int i = 0; i < testCount; i++)
    {
        const InputData& td = Testbuffer[i];
        sample_to_feature(td, m);
        char ret = static_cast<char>(svmpredict1.predict(m));
        if (ret == td.lable)
        {
            count_test++;
        }
        if (i % step == 0)
        {
            cout<<(i*100/testCount)<<"%"<<endl;
        }
    }
    cout<<"test finished!"<<endl;
    cout<<"crect:"<<(count_test*1.0/testCount)*100<<"%"<<endl;
    cout<<"totall:"<<count_test<<endl;
}
其中一个数据(已经被博主归一化了大小):
导入数据的输出提示:
模型训练提示输出:
测试集测试结果:线性核下正确率92.83%,低于上面网站上给出的正确率,可能是参数没有设置好。
相关文章推荐
- Linux socket 初步
- 使用C++实现JNI接口需要注意的事项
- linux lsof详解
- linux 文件权限
- Linux 执行数学运算
- 10 篇对初学者和专家都有用的 Linux 命令教程
- Linux 与 Windows 对UNICODE 的处理方式
- Ubuntu12.04下QQ完美走起啊!走起啊!有木有啊!
- 解決Linux下Android开发真机调试设备不被识别问题
- 运维入门
- 运维提升
- Linux 自检和 SystemTap
- Ubuntu Linux使用体验
- c语言实现hashmap(转载)
- Linux 信号signal处理机制
- linux下mysql添加用户
- 关于指针的一些事情
- Scientific Linux 5.5 图形安装教程
- 基于 Linux 集群环境上 GPFS 的问题诊断