[035] Classifying Breast Cancer Detection Data with SVM in Java
2016-04-26 21:42
Background:
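I have recently been studying SVM classification. Most related material online covers SVM theory and derivations, or shows how to run svm-train and svm-predict from the command line; very little walks through an actual data analysis. While searching I found an excellent library, LIBSVM. LIBSVM -- A Library for Support Vector Machines was developed by Chih-Chung Chang and Chih-Jen Lin at National Taiwan University, with support from the National Science Council of Taiwan. It wraps SVM well and makes data analysis much more convenient. More importantly, the project collects a large number of datasets for classification, regression, and multi-label problems, so you can study SVM in depth from the data side. Datasets: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/. Official site: https://www.csie.ntu.edu.tw/~cjlin/libsvm/ .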
Preparing the training and test data:
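The datasets can be downloaded directly from the LIBSVM website. This example uses the UCI breast-cancer dataset. Each line of the training and test files has the form: <label> <index1>:<value1> <index2>:<value2> ...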
For example:
4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000
4.000000 1:1100524.000000 2:6.000000 3:10.000000 4:10.000000 5:2.000000 6:8.000000 7:10.000000 8:7.000000 9:3.000000 10:3.000000
4.000000 1:1102573.000000 2:5.000000 3:6.000000 4:5.000000 5:6.000000 6:10.000000 7:1.000000 8:3.000000 9:1.000000 10:1.000000
Link: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#breast-cancer
Field meanings (field 0 is the label; fields 1 to 10 are the feature indices 1 to 10 in the file):
0. Class: 2 for benign, 4 for malignant
1. Sample code number: id number
2. Clump Thickness: 1 - 10
3. Uniformity of Cell Size: 1 - 10
4. Uniformity of Cell Shape: 1 - 10
5. Marginal Adhesion: 1 - 10
6. Single Epithelial Cell Size: 1 - 10
7. Bare Nuclei: 1 - 10
8. Bland Chromatin: 1 - 10
9. Normal Nucleoli: 1 - 10
10. Mitoses: 1 - 10
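To make the file format concrete, here is a minimal sketch that parses one line into its label and an index-to-value map. It uses the label coding from field 0 (2 = benign, 4 = malignant). The class `LibsvmLine` and its methods are hypothetical illustration code, not part of LIBSVM. LIBSVM's own readers tokenize the line in a similar way.

```java
import java.util.TreeMap;

// Hypothetical helper (not part of LIBSVM): parse one line of the sparse format
// "<label> <index1>:<value1> <index2>:<value2> ..." into a label and an index -> value map.
class LibsvmLine {
    final double label;
    final TreeMap<Integer, Double> features; // indices missing from the line are implicitly 0

    LibsvmLine(double label, TreeMap<Integer, Double> features) {
        this.label = label;
        this.features = features;
    }

    static LibsvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        TreeMap<Integer, Double> features = new TreeMap<>();
        for (int i = 1; i < tokens.length; i++) {
            int colon = tokens[i].indexOf(':');
            int index = Integer.parseInt(tokens[i].substring(0, colon));
            double value = Double.parseDouble(tokens[i].substring(colon + 1));
            features.put(index, value);
        }
        return new LibsvmLine(label, features);
    }

    // In this dataset the label is 2 (benign) or 4 (malignant).
    String className() {
        if (label == 2.0) return "benign";
        if (label == 4.0) return "malignant";
        return "unknown";
    }

    public static void main(String[] args) {
        LibsvmLine row = parse("4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000"
                + " 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000");
        // feature 1 is the sample ID, feature 2 is Clump Thickness
        System.out.println(row.className() + ", " + row.features.size()
                + " features, clump thickness = " + row.features.get(2));
    }
}
```

Running `main` prints `malignant, 10 features, clump thickness = 10.0` for the first sample row. A `TreeMap` keeps the indices in ascending order, which is the order the format expects.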
Project setup:
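Create a Java project and add the LibSVM JAR. You also need to add three files from the java directory: svm_train.java, svm_scale.java, and svm_predict.java. These classes build on the LibSVM core and take the command-line parameters as a String[] method argument, which makes them convenient to call as an API. The remaining file, svm_toy.java, is a graphical demo and does not need to be imported. Put the training and test data files in the project directory so they are easy to reference. Note that in the stock LIBSVM distribution, svm_train.main and svm_predict.main return void. The code below relies on modified versions in which svm_train.main returns the model file path and svm_predict.main returns the accuracy.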
The Java code that calls the LibSVM API for classification:
import java.io.IOException;
import libsvm.*;

/** JAVA test code for LibSVM
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LibSVMTest {
    public static void main(String[] args) throws IOException {
        // Test for svm_train and svm_predict
        // svm_train:
        //   param: String[], parse result of command line parameter of svm-train
        //   return: String, the path of the model file
        // svm_predict:
        //   param: String[], parse result of command line parameter of svm-predict, including the model file
        //   return: Double, the accuracy of SVM classification
        String[] trainArgs = {"UCI-breast-cancer-tra"}; // path of the training file
        String modelFile = svm_train.main(trainArgs);
        String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"}; // test file, model file, result file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);

        // Test for cross validation
        // String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
        // modelFile = svm_train.main(crossValidationTrainArgs);
        // System.out.print("Cross validation is done! The modelFile is " + modelFile);
    }
}
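The commented-out block uses svm_train's `-v 10` option for 10-fold cross-validation. The sketch below shows how such a split partitions the samples: every sample is tested exactly once, and fold sizes differ by at most one. It is a simplified illustration with a hypothetical class name. LIBSVM's own cross-validation also shuffles the data, and for classification it stratifies the folds by class, which this sketch does not do.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Simplified k-fold index split (hypothetical helper, no class stratification).
class KFold {
    /** Returns k folds; in round f, folds.get(f) is the test set and all other folds are training data. */
    static List<List<Integer>> split(int n, int k, long seed) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) order.add(i);
        Collections.shuffle(order, new Random(seed)); // random assignment, reproducible via the seed
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) folds.add(new ArrayList<>());
        for (int i = 0; i < n; i++) folds.get(i % k).add(order.get(i)); // deal indices round-robin
        return folds;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = split(25, 10, 42L);
        for (int f = 0; f < folds.size(); f++) {
            System.out.println("fold " + f + ": test size " + folds.get(f).size());
        }
    }
}
```

The cross-validation accuracy is then the fraction of samples predicted correctly across all k rounds. Because no sample is used for training and testing in the same round, this gives a fairer estimate than the training accuracy.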
Output:
.*
optimization finished, #iter = 1223
nu = 0.6996186233933985
obj = -271.992875483972, rho = 0.4257786283326366
nSV = 639, nBSV = 222
Total nSV = 639
Accuracy = 69.23076923076923% (27/39) (classification)
SVM Classification is done! The accuracy is 0.6923076923076923
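The accuracy is only 0.69 (27 of the 39 test samples are classified correctly). A likely cause is that the data is unscaled. Feature 1, the sample ID, is 1099510 in the first sample row, while every other feature lies between 1 and 10. As a result, the ID dominates the distances computed by svm_train's default RBF kernel.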
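To see the scale problem numerically, the sketch below computes the RBF kernel between the first two sample rows. It assumes svm_train's defaults, namely the RBF kernel K(x, z) = exp(-gamma * ||x - z||^2) with gamma = 1/num_features (0.1 for 10 features). The class name is hypothetical.

```java
// Illustration (hypothetical class): how an unscaled ID column swamps the RBF kernel.
// Assumes svm_train's defaults: RBF kernel with gamma = 1 / number of features.
class RbfScaleDemo {
    static double squaredDistance(double[] x, double[] z) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            double d = x[i] - z[i];
            sum += d * d;
        }
        return sum;
    }

    static double rbf(double[] x, double[] z, double gamma) {
        return Math.exp(-gamma * squaredDistance(x, z));
    }

    public static void main(String[] args) {
        // First two sample rows from the training file (features 1..10).
        double[] a = {1099510, 10, 4, 3, 1, 3, 3, 6, 5, 2};
        double[] b = {1100524, 6, 10, 10, 2, 8, 10, 7, 3, 3};
        double gamma = 1.0 / a.length;
        double total = squaredDistance(a, b);
        double fromId = (a[0] - b[0]) * (a[0] - b[0]);
        System.out.println("squared distance = " + total);             // 1028378.0
        System.out.println("share from feature 1 = " + fromId / total); // about 0.9998
        System.out.println("K(a, b) = " + rbf(a, b, gamma));            // underflows to 0.0
    }
}
```

The ID difference (1014) contributes 1028196 of the 1028378 squared distance, so the kernel value underflows to 0.0 even though the medical features of the two rows are comparable. After every feature is scaled into [-1, 1], a single column can contribute at most 4 to the squared distance.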
Improving the program:
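Use svm_scale.java to scale (normalize) the data. The stock svm_scale prints the scaled data to standard output. The scaled data must therefore be written separately to UCI-breast-cancer-tra-scale and UCI-breast-cancer-test-scale, and then passed through training and prediction again. svm_scale.java needs the following changes: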
Change output_target to:
private String output_target(double value)
{
    if(y_scaling)
    {
        if(value == y_min)
            value = y_lower;
        else if(value == y_max)
            value = y_upper;
        else
            value = y_lower + (y_upper-y_lower) * (value-y_min) / (y_max-y_min);
    }
    System.out.print(value + " ");
    // also return the label text so that run() can write it to the output file
    return value + " ";
}
Change output to:
private String output(int index, double value)
{
    /* skip single-valued attribute */
    if(feature_max[index] == feature_min[index])
        return "";

    if(value == feature_min[index])
        value = lower;
    else if(value == feature_max[index])
        value = upper;
    else
        value = lower + (upper-lower) * (value-feature_min[index]) / (feature_max[index]-feature_min[index]);

    if(value != 0)
    {
        System.out.print(index + ":" + value + " ");
        new_num_nonzeros++;
        // also return "index:value " so that run() can write it to the output file
        return index + ":" + value + " ";
    }
    return "";   // zero values are omitted in LIBSVM's sparse format
}
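The core of output() is a per-feature min-max mapping: x' = lower + (upper - lower) * (x - min) / (max - min). It skips constant columns and drops zero results. output_target applies the same formula to the labels when -y is given. The standalone sketch below reproduces this logic so it can be checked in isolation. The class and method names are hypothetical. The default range lower = -1, upper = 1 matches svm_scale.

```java
// Standalone sketch of svm_scale's per-feature min-max mapping (hypothetical class).
class MinMaxScale {
    /** Returns "index:value " for a nonzero scaled value, or "" when the feature is skipped or maps to 0. */
    static String scaleEntry(int index, double value, double min, double max, double lower, double upper) {
        if (max == min) return "";             // single-valued attribute carries no information
        double scaled;
        if (value == min) scaled = lower;      // exact endpoints avoid rounding error
        else if (value == max) scaled = upper;
        else scaled = lower + (upper - lower) * (value - min) / (max - min);
        return scaled != 0 ? index + ":" + scaled + " " : "";
    }

    public static void main(String[] args) {
        // Clump Thickness observed over [1, 10], scaled to [-1, 1]
        System.out.println(scaleEntry(2, 10, 1, 10, -1, 1));           // "2:1.0 "
        System.out.println(scaleEntry(2, 1, 1, 10, -1, 1));            // "2:-1.0 "
        System.out.println(scaleEntry(2, 5.5, 1, 10, -1, 1).isEmpty()); // true: midpoint maps to 0 and is omitted
    }
}
```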
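Two parts of run() need to change. First, declare a field for the output path (for example, private String save_filePath = null;) and add a -p case to the option switch: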
switch(argv[i-1].charAt(1))
{
    case 'l': lower = Double.parseDouble(argv[i]); break;
    case 'u': upper = Double.parseDouble(argv[i]); break;
    case 'y':
        y_lower = Double.parseDouble(argv[i]);
        ++i;
        y_upper = Double.parseDouble(argv[i]);
        y_scaling = true;
        break;
    case 's': save_filename = argv[i]; break;
    case 'r': restore_filename = argv[i]; break;
    case 'p': save_filePath = argv[i]; break;   // new: file that receives the scaled data
    default:
        System.err.println("unknown option");
        exit_with_help();
}
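svm_scale parses its options in "-x value" pairs, and the last argument is the input file. The sketch below mirrors that loop shape, including the -p option from the switch above, as a standalone class that can be tested. `ScaleOptions` is a hypothetical helper, not a LIBSVM class. To keep it short, it leaves out the two-valued -y option.

```java
// Hypothetical helper mirroring svm_scale's option loop, including the new -p option.
class ScaleOptions {
    double lower = -1.0;
    double upper = 1.0;
    String saveRange;     // -s: file to save the scaling ranges to
    String restoreRange;  // -r: file to restore previously saved ranges from
    String outputPath;    // -p: file that receives the scaled data
    String inputFile;

    static ScaleOptions parse(String[] argv) {
        ScaleOptions o = new ScaleOptions();
        int i;
        for (i = 0; i < argv.length; i++) {
            if (argv[i].charAt(0) != '-') break; // first non-option argument ends the options
            ++i;                                  // step to the option's value
            if (i >= argv.length) throw new IllegalArgumentException("missing value for " + argv[i - 1]);
            switch (argv[i - 1].charAt(1)) {
                case 'l': o.lower = Double.parseDouble(argv[i]); break;
                case 'u': o.upper = Double.parseDouble(argv[i]); break;
                case 's': o.saveRange = argv[i]; break;
                case 'r': o.restoreRange = argv[i]; break;
                case 'p': o.outputPath = argv[i]; break;
                default: throw new IllegalArgumentException("unknown option " + argv[i - 1]);
            }
        }
        if (i != argv.length - 1) throw new IllegalArgumentException("expected exactly one input file at the end");
        o.inputFile = argv[i];
        return o;
    }

    public static void main(String[] args) {
        ScaleOptions o = parse(new String[]{"-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"});
        System.out.println(o.inputFile + " -> " + o.outputPath + " in [" + o.lower + ", " + o.upper + "]");
    }
}
```

Because each option consumes the next argument unconditionally, a negative value such as `-l -1` is read correctly as a value rather than as a new option.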
Second, rewrite pass 3 so that every scaled line is also written to that file. svm_scale.java additionally needs import java.io.BufferedWriter; and import com.yuan.util.FileStream;.
BufferedWriter bw = FileStream.fileWriterStream(save_filePath, false); // false: overwrite on re-runs instead of appending
/* pass 3: scale */
while(readline(fp) != null)
{
    int next_index = 1;
    double target;
    double value;
    String dataLine = "";

    StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
    target = Double.parseDouble(st.nextToken());
    dataLine = output_target(target);
    while(st.hasMoreElements())
    {
        index = Integer.parseInt(st.nextToken());
        value = Double.parseDouble(st.nextToken());
        for (i = next_index; i<index; i++)
            dataLine += output(i, 0);
        dataLine += output(index, value);
        next_index = index + 1;
    }

    for(i=next_index;i<= max_index;i++)
        dataLine += output(i, 0);   // keep the returned text (it was discarded before)
    System.out.print("\n");
    dataLine += "\n";
    FileStream.writerData(bw, dataLine);
}

if (new_num_nonzeros > num_nonzeros)
    System.err.print(
        "WARNING: original #nonzeros " + num_nonzeros+"\n"
        +" new #nonzeros " + new_num_nonzeros+"\n"
        +"Use -l 0 if many original feature values are zeros\n");

fp.close();
bw.close();
Create a new FileStream class for writing the data:
package com.yuan.util;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;

public class FileStream {
    // Opens fileName for writing; append = true adds to the end of an existing file instead of overwriting it
    public static BufferedWriter fileWriterStream(String fileName, boolean append) {
        BufferedWriter fp_save = null;
        try {
            fp_save = new BufferedWriter(new FileWriter(fileName, append));
        } catch(IOException e) {
            System.err.println("can't open file " + fileName);
            System.exit(1);
        }
        return fp_save;
    }

    // Writes one chunk of text; the caller is responsible for closing the writer
    public static void writerData(BufferedWriter bw, String data) throws IOException {
        bw.write(data);
    }
}
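When fileWriterStream is called with append = true, each run of the program adds another copy of the scaled data to the same file. The sketch below shows one way to write the lines so that a re-run replaces the file instead. It uses `java.nio.file.Files.write`, which by default opens the file with CREATE, TRUNCATE_EXISTING, and WRITE. The class name `ScaledFileWriter` is hypothetical, and the sketch is an alternative to the FileStream helper, not the author's code.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

// Hypothetical alternative to FileStream: overwrite the output file on every run.
class ScaledFileWriter {
    static void writeLines(Path path, List<String> lines) throws IOException {
        // No OpenOptions given, so Files.write uses CREATE, TRUNCATE_EXISTING and WRITE.
        Files.write(path, lines, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        Path out = Files.createTempFile("UCI-breast-cancer-tra-scale", ".txt");
        List<String> lines = Arrays.asList("4 1:-0.5 2:1", "2 1:0.25 2:-1");
        writeLines(out, lines);
        writeLines(out, lines); // a second run still leaves 2 lines, not 4
        System.out.println(Files.readAllLines(out, StandardCharsets.UTF_8).size()); // 2
        Files.delete(out);
    }
}
```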
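Modify the test class (called SVMClassifierTest here; it plays the same role as LibSVMTest above) so that its main method reads: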
// Test for svm_train and svm_predict
// svm_train:
//   param: String[], parse result of command line parameter of svm-train
//   return: String, the path of the model file
// svm_predict:
//   param: String[], parse result of command line parameter of svm-predict, including the model file
//   return: Double, the accuracy of SVM classification
svm_scale.main(new String[]{"-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"});  // scale the training data and save it
svm_scale.main(new String[]{"-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"}); // scale the test data and save it
String[] scaleTrainArgs = {"UCI-breast-cancer-tra-scale"}; // path of the scaled training file
String modelFile = svm_train.main(scaleTrainArgs);
String[] testArgs = {"UCI-breast-cancer-test-scale", modelFile, "UCI-breast-cancer-result"}; // scaled test file, model file, result file
Double accuracy = svm_predict.main(testArgs);
System.out.println("SVM Classification is done! The accuracy is " + accuracy);

// Test for cross validation
// String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
// modelFile = svm_train.main(crossValidationTrainArgs);
// System.out.print("Cross validation is done! The modelFile is " + modelFile);
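The main method above calls svm_scale separately on the training and test files. Each file is therefore scaled with its own minimum and maximum, and the same raw value can map to different scaled values in the two files. svm_scale's own options avoid this: -s <file> saves the ranges computed from the training file, and -r <file> restores those ranges when scaling the test file. The sketch below shows the idea on dense arrays. The class name and toy data are hypothetical. The accuracies reported in this article were obtained with independent scaling, so switching to -s/-r may change them.

```java
import java.util.Arrays;

// Hypothetical sketch: learn per-column ranges on the training rows only, then reuse them for test rows.
class TrainTestScaling {
    final double[] min;
    final double[] max;

    private TrainTestScaling(double[] min, double[] max) {
        this.min = min;
        this.max = max;
    }

    /** Learns per-column min/max from the training rows only. */
    static TrainTestScaling fit(double[][] train) {
        int d = train[0].length;
        double[] min = new double[d];
        double[] max = new double[d];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
        return new TrainTestScaling(min, max);
    }

    /** Maps a row with the training ranges; test values outside those ranges land outside [lower, upper]. */
    double[] transform(double[] row, double lower, double upper) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            // constant training columns become 0 (svm_scale skips them, which is also 0 in sparse format)
            out[j] = (max[j] == min[j]) ? 0.0
                    : lower + (upper - lower) * (row[j] - min[j]) / (max[j] - min[j]);
        }
        return out;
    }

    public static void main(String[] args) {
        TrainTestScaling s = fit(new double[][]{{1, 2}, {10, 8}});
        System.out.println(Arrays.toString(s.transform(new double[]{10, 5}, -1, 1))); // [1.0, 0.0]
    }
}
```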
Result:
*
optimization finished, #iter = 97
nu = 0.0711047842614367
obj = -78.46733678185721, rho = -0.9253740588830286
nSV = 99, nBSV = 83
Total nSV = 99
Accuracy = 89.74358974358975% (70/78) (classification)
SVM Classification is done! The accuracy is 0.8974358974358975
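The accuracy improves substantially, from 69.2% to 89.7%. Note, however, that this run reports 78 test samples, against 39 in the first run. A likely explanation is that the scaled files were written with append = true, so re-running the program appended a second copy of the data. The two figures are therefore not strictly comparable, and the scaled files should be deleted or overwritten before each run.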
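The accuracy printed by svm_predict is the number of correct predictions divided by the number of test samples. That is why the output shows the pair (27/39) or (70/78) next to the percentage. The small sketch below makes that explicit. The class name is hypothetical, and the label arrays are toy data, not the breast-cancer predictions.

```java
// Hypothetical sketch: classification accuracy = correct predictions / total samples.
class AccuracyDemo {
    static double accuracy(double[] predicted, double[] actual) {
        if (predicted.length != actual.length || actual.length == 0)
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == actual[i]) correct++;
        }
        return (double) correct / actual.length;
    }

    public static void main(String[] args) {
        double[] actual    = {2, 2, 4, 4, 2}; // toy labels: 2 = benign, 4 = malignant
        double[] predicted = {2, 4, 4, 4, 2};
        System.out.println(accuracy(predicted, actual)); // 4 of 5 correct: 0.8
    }
}
```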
This completes a simple LIBSVM workflow and its improvement through data scaling.
References:
Chih-Chung Chang and Chih-Jen Lin, LIBSVM: a library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2:27:1–27:27, 2011. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm.