
[035] Classifying Breast Cancer Detection Data with SVM in Java

2016-04-26 21:42

Background:

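While studying SVM classification recently, I found that most material online covers SVM theory and derivations, or running svm-train and svm-predict from the command line; very little shows an actual data analysis end to end. Searching further, I came across an excellent library, LIBSVM. LIBSVM (A Library for Support Vector Machines) is developed by Chih-Chung Chang and Chih-Jen Lin at National Taiwan University, with support from the National Science Council of Taiwan. It wraps SVM training and prediction in a form that makes data analysis convenient. More importantly, its authors have collected a large number of datasets for classification, regression and multi-label problems, which makes it possible to study SVMs in depth from the data side. Dataset page: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/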

Official site: https://www.csie.ntu.edu.tw/~cjlin/libsvm/

Preparing training and test data:

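The datasets can be downloaded from the LIBSVM website; this example uses the UCI breast-cancer dataset. Each line of the training and test files is one sample: the class label followed by index:value pairs in ascending index order (features whose value is 0 may be omitted):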

<label> <index1>:<value1> <index2>:<value2> ...


For example:

4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000

4.000000 1:1100524.000000 2:6.000000 3:10.000000 4:10.000000 5:2.000000 6:8.000000 7:10.000000 8:7.000000 9:3.000000 10:3.000000

4.000000 1:1102573.000000 2:5.000000 3:6.000000 4:5.000000 5:6.000000 6:10.000000 7:1.000000 8:3.000000 9:1.000000 10:1.000000
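To make the format concrete, here is a minimal sketch of parsing one such line in plain Java, with no LIBSVM dependency (the class name LibsvmLine is hypothetical). It splits the line on whitespace, reads the label, and stores the index:value pairs in a sorted map, so an index that is absent from the line simply has no entry, meaning 0.

```java
import java.util.TreeMap;

/** One sample in LIBSVM's sparse text format: "<label> <index1>:<value1> <index2>:<value2> ...". */
final class LibsvmLine {
    final double label;
    // feature index -> value; an index with no entry is implicitly 0
    final TreeMap<Integer, Double> features;

    private LibsvmLine(double label, TreeMap<Integer, Double> features) {
        this.label = label;
        this.features = features;
    }

    static LibsvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        TreeMap<Integer, Double> features = new TreeMap<>();
        for (int i = 1; i < tokens.length; i++) {
            int colon = tokens[i].indexOf(':');
            if (colon < 0) {
                throw new IllegalArgumentException("expected index:value but got " + tokens[i]);
            }
            int index = Integer.parseInt(tokens[i].substring(0, colon));
            double value = Double.parseDouble(tokens[i].substring(colon + 1));
            features.put(index, value);
        }
        return new LibsvmLine(label, features);
    }

    public static void main(String[] args) {
        LibsvmLine row = parse("4.000000 1:1099510.000000 2:10.000000 3:4.000000");
        System.out.println(row.label + " " + row.features); // 4.0 {1=1099510.0, 2=10.0, 3=4.0}
    }
}
```

LIBSVM's own Java readers tokenize on spaces and colons with a StringTokenizer instead; the map here is only a convenient in-memory representation.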

Link: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#breast-cancer

Field meanings:

0. Class (the label): 2 for benign, 4 for malignant

1. Sample code number: id number

2. Clump Thickness: 1 - 10

3. Uniformity of Cell Size: 1 - 10

4. Uniformity of Cell Shape: 1 - 10

5. Marginal Adhesion: 1 - 10

6. Single Epithelial Cell Size: 1 - 10

7. Bare Nuclei: 1 - 10

8. Bland Chromatin: 1 - 10

9. Normal Nucleoli: 1 - 10

10. Mitoses: 1 - 10
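A small lookup table makes printed features and predictions easier to read. This sketch (hypothetical class name) maps each LIBSVM feature index to the field name listed above and decodes the class label. Note that index 1 is an identifier rather than a measurement, and in the sample lines its values (around 10^6) are far larger than the 1 to 10 range of the other fields.

```java
/** Field names of the breast-cancer data; array position = LIBSVM feature index. */
final class BreastCancerFields {
    static final String[] NAMES = {
        null, // position 0 unused: the class is the label, not a feature
        "Sample code number",
        "Clump Thickness",
        "Uniformity of Cell Size",
        "Uniformity of Cell Shape",
        "Marginal Adhesion",
        "Single Epithelial Cell Size",
        "Bare Nuclei",
        "Bland Chromatin",
        "Normal Nucleoli",
        "Mitoses"
    };

    /** Maps the class label (2 or 4) to a readable name. */
    static String className(double label) {
        if (label == 2.0) {
            return "benign";
        }
        if (label == 4.0) {
            return "malignant";
        }
        throw new IllegalArgumentException("unexpected label " + label);
    }

    public static void main(String[] args) {
        System.out.println(NAMES[7] + ", " + className(4.0)); // Bare Nuclei, malignant
    }
}
```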

Project setup:

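Create a Java project and add the LIBSVM JAR. In addition, copy three source files from the java directory of the LIBSVM distribution into the project: svm_train.java, svm_scale.java and svm_predict.java. These classes wrap the LIBSVM library as command-line tools; their main methods receive the command-line options as a String[] argument, so they can also be called directly from Java code. The remaining file, svm_toy.java, is a graphical demo and does not need to be imported. Note that the main methods in the stock distribution return void: the code below assumes modified versions in which svm_train.main returns the model file name and svm_predict.main returns the accuracy.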

Place the training and test data files in the project directory so they can be referenced by relative path.

The Java code that calls the LIBSVM API to classify the data is as follows:

import java.io.IOException;

import libsvm.*;

/** Java test code for LibSVM
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LibSVMTest {

    public static void main(String[] args) throws IOException {
        // Test for svm_train and svm_predict
        // (assumes the modified svm_train/svm_predict; the stock main methods return void)
        // svm_train.main:
        //     param:  String[], the command-line arguments of svm-train
        //     return: String, the path of the model file
        // svm_predict.main:
        //     param:  String[], the command-line arguments of svm-predict, including the model file
        //     return: Double, the accuracy of the SVM classification
        String[] trainArgs = {"UCI-breast-cancer-tra"}; // training file
        String modelFile = svm_train.main(trainArgs);
        String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"}; // test file, model file, result file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);

        // Test for cross validation
        // String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
        // modelFile = svm_train.main(crossValidationTrainArgs);
        // System.out.print("Cross validation is done! The modelFile is " + modelFile);
    }

}


Output:

.*
optimization finished, #iter = 1223
nu = 0.6996186233933985
obj = -271.992875483972, rho = 0.4257786283326366
nSV = 639, nBSV = 222
Total nSV = 639
Accuracy = 69.23076923076923% (27/39) (classification)
SVM Classification is done! The accuracy is 0.6923076923076923


The accuracy is only 0.69 (27 of the 39 test samples are classified correctly).
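svm_predict obtains this figure by comparing each predicted label with the label in the test file. Assuming the result file holds one predicted label per line (which is how svm_predict writes it when -b is not used), the ratio can be recomputed independently with a sketch like this (class and method names are hypothetical):

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/** Recomputes classification accuracy from the test file and svm_predict's result file. */
final class AccuracyCheck {
    /** Fraction of positions where the predicted label equals the true label. */
    static double accuracy(List<Double> predicted, List<Double> actual) {
        if (actual.isEmpty() || predicted.size() != actual.size()) {
            throw new IllegalArgumentException("need two non-empty label lists of equal length");
        }
        int correct = 0;
        for (int i = 0; i < actual.size(); i++) {
            if (predicted.get(i).doubleValue() == actual.get(i).doubleValue()) {
                correct++;
            }
        }
        return (double) correct / actual.size();
    }

    /** Reads the first whitespace-separated token of every non-blank line as a label. */
    static List<Double> readLabels(Path file) throws IOException {
        List<Double> labels = new ArrayList<>();
        for (String line : Files.readAllLines(file)) {
            String trimmed = line.trim();
            if (!trimmed.isEmpty()) {
                labels.add(Double.parseDouble(trimmed.split("\\s+")[0]));
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        // File names as used in the article
        List<Double> actual = readLabels(Paths.get("UCI-breast-cancer-test"));
        List<Double> predicted = readLabels(Paths.get("UCI-breast-cancer-result"));
        System.out.println("Accuracy = " + accuracy(predicted, actual));
    }
}
```

The result is simply correct / total, e.g. 27/39 for the run above.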

Improving the program:

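Use svm_scale.java to scale (normalize) every feature to a common range ([-1, 1] by default) and classify again on the scaled data. The stock svm_scale only prints the scaled data to standard output, so it has to be modified to store the results in separate files, UCI-breast-cancer-tra-scale and UCI-breast-cancer-test-scale. The test data should be scaled with the same per-feature minimum and maximum as the training data; svm_scale supports this through its -s (save scaling parameters) and -r (restore scaling parameters) options.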
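svm_scale can save the per-feature minimum and maximum computed on one file (-s) and restore them for another (-r). The point is that these parameters are learned from the training data only and then reused unchanged for the test data. A minimal dense-array sketch of that idea (a hypothetical class, not part of LIBSVM), using the same linear formula as svm_scale with the range [-1, 1]:

```java
import java.util.Arrays;

/** Min-max scaling to [lower, upper] with parameters learned from training rows only. */
final class MinMaxScaler {
    private final double lower;
    private final double upper;
    private double[] min;
    private double[] max;

    MinMaxScaler(double lower, double upper) {
        this.lower = lower;
        this.upper = upper;
    }

    /** Records each column's minimum and maximum over the training rows (the role of -s). */
    void fit(double[][] train) {
        int d = train[0].length;
        min = new double[d];
        max = new double[d];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    /** Scales a row with the stored training min/max (the role of -r); values are not clipped. */
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            if (max[j] == min[j]) {
                out[j] = 0.0; // constant column: svm_scale omits it, which reads back as 0
            } else {
                out[j] = lower + (upper - lower) * (row[j] - min[j]) / (max[j] - min[j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        MinMaxScaler scaler = new MinMaxScaler(-1.0, 1.0);
        scaler.fit(new double[][] {{1, 10}, {5, 20}, {9, 30}});
        // 40 lies above the training maximum of column 2, so it scales past 1
        System.out.println(Arrays.toString(scaler.transform(new double[] {5, 40}))); // [0.0, 2.0]
    }
}
```

Test values outside the training range end up outside [lower, upper]; like svm_scale, the sketch does not clip them.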

svm_scale.java needs to be modified in several places.

Change output_target so that it also returns the (possibly scaled) label as a string:

private String output_target(double value)
{
    if(y_scaling)
    {
        if(value == y_min)
            value = y_lower;
        else if(value == y_max)
            value = y_upper;
        else
            value = y_lower + (y_upper-y_lower) *
                (value-y_min) / (y_max-y_min);
    }

    System.out.print(value + " ");
    return value + " ";
}


Change output so that it returns the scaled index:value pair as a string (or a blank when nothing is written):

private String output(int index, double value)
{
    /* skip single-valued attribute */
    if(feature_max[index] == feature_min[index])
        return " ";

    if(value == feature_min[index])
        value = lower;
    else if(value == feature_max[index])
        value = upper;
    else
        value = lower + (upper-lower) *
            (value-feature_min[index])/
            (feature_max[index]-feature_min[index]);

    if(value != 0)
    {
        System.out.print(index + ":" + value + " ");
        new_num_nonzeros++;
        return index + ":" + value + " ";
    }
    return " ";
}


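Two parts of run need to be changed. First, add a -p option (the path of the output file) to the option parsing. This also requires declaring a field private String save_filePath = null; in svm_scale and adding import com.yuan.util.FileStream; at the top of svm_scale.java: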

switch(argv[i-1].charAt(1))
{
    case 'l': lower = Double.parseDouble(argv[i]); break;
    case 'u': upper = Double.parseDouble(argv[i]); break;
    case 'y':
        y_lower = Double.parseDouble(argv[i]);
        ++i;
        y_upper = Double.parseDouble(argv[i]);
        y_scaling = true;
        break;
    case 's': save_filename = argv[i]; break;
    case 'r': restore_filename = argv[i]; break;
    case 'p': save_filePath = argv[i]; break; // new: path of the scaled output file
    default:
        System.err.println("unknown option");
        exit_with_help();
}


Second, change pass 3 (the scaling loop) so that each scaled line is also collected into a string and written to the file given by -p:

// -p is required in this version; false = overwrite, so re-running does not append duplicate lines
BufferedWriter bw = FileStream.fileWriterStream(save_filePath, false);

/* pass 3: scale */
while(readline(fp) != null)
{
    int next_index = 1;
    double target;
    double value;
    String dataLine = "";

    StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
    target = Double.parseDouble(st.nextToken());
    dataLine = output_target(target);
    while(st.hasMoreElements())
    {
        index = Integer.parseInt(st.nextToken());
        value = Double.parseDouble(st.nextToken());
        for (i = next_index; i<index; i++)
            dataLine += output(i, 0);
        dataLine += output(index, value);
        next_index = index + 1;
    }

    for(i=next_index;i<= max_index;i++)
        dataLine += output(i, 0); // trailing features must be written to the file too
    System.out.print("\n");
    dataLine += "\n";
    FileStream.writerData(bw, dataLine);
}
if (new_num_nonzeros > num_nonzeros)
    System.err.print(
        "WARNING: original #nonzeros " + num_nonzeros+"\n"
        +"         new      #nonzeros " + new_num_nonzeros+"\n"
        +"Use -l 0 if many original feature values are zeros\n");

fp.close();
bw.close();
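The pass-3 loop above visits every feature index from 1 to max_index, treating indices missing from the input line as 0, and writes only nonzero scaled values. A compact, self-contained sketch of that per-line computation (class and method names are hypothetical; the per-feature min/max are passed in as 1-based arrays rather than read from svm_scale's fields):

```java
import java.util.TreeMap;

/** Sketch of svm_scale's pass 3 for one line: scale every feature index 1..featMin.length-1. */
final class SparseLineScaler {
    /**
     * featMin/featMax are 1-based (element 0 unused); features maps index -> value,
     * and an index missing from the map stands for the value 0.
     */
    static String scaleLine(double label, TreeMap<Integer, Double> features,
                            double[] featMin, double[] featMax,
                            double lower, double upper) {
        StringBuilder sb = new StringBuilder().append(label);
        for (int i = 1; i < featMin.length; i++) {
            if (featMax[i] == featMin[i]) {
                continue; // skip single-valued attributes, as output() does
            }
            double v = features.getOrDefault(i, 0.0);
            double scaled;
            if (v == featMin[i]) {
                scaled = lower;
            } else if (v == featMax[i]) {
                scaled = upper;
            } else {
                scaled = lower + (upper - lower) * (v - featMin[i]) / (featMax[i] - featMin[i]);
            }
            if (scaled != 0) { // zeros stay implicit in the sparse format
                sb.append(' ').append(i).append(':').append(scaled);
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        double[] min = {0, 1, 0, 5}; // index 3 is constant (min == max)
        double[] max = {0, 9, 10, 5};
        TreeMap<Integer, Double> f = new TreeMap<>();
        f.put(1, 9.0);
        f.put(2, 7.5);
        System.out.println(scaleLine(4.0, f, min, max, -1.0, 1.0)); // 4.0 1:1.0 2:0.5
    }
}
```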


Create a new FileStream class for writing the data to a file:

package com.yuan.util;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;

public class FileStream {

    // Opens fileName for writing: append = true adds to an existing file, false overwrites it
    public static BufferedWriter fileWriterStream(String fileName, boolean append){
        BufferedWriter fp_save = null;
        try {
            fp_save = new BufferedWriter(new FileWriter(fileName, append));
        } catch(IOException e) {
            System.err.println("can't open file " + fileName);
            System.exit(1);
        }
        return fp_save;
    }

    // Writes one chunk of text (here: one scaled line) to the writer
    public static void writerData(BufferedWriter bw, String data) throws IOException{
        bw.write(data);
    }
}
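FileStream exits the whole program when the file cannot be opened and leaves closing the writer to the caller. If exceptions and automatic closing are preferred, a try-with-resources sketch using java.nio could look like this (the class name is hypothetical). Files.newBufferedWriter with no open options creates the file or truncates an existing one, so re-running never appends duplicate lines.

```java
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

/** Writes scaled lines to a file, replacing old contents; the writer is closed automatically. */
final class ScaledFileWriter {
    static void writeLines(Path file, List<String> lines) throws IOException {
        // No OpenOptions given: create the file if needed, otherwise truncate it
        try (BufferedWriter bw = Files.newBufferedWriter(file, StandardCharsets.UTF_8)) {
            for (String line : lines) {
                bw.write(line);
                bw.write("\n"); // one sample per line, as LIBSVM expects
            }
        }
    }

    public static void main(String[] args) throws IOException {
        Path out = Files.createTempFile("scaled", ".txt"); // temp file, so no data file is overwritten
        writeLines(out, Arrays.asList("2.0 2:-0.5 3:-1.0", "4.0 2:1.0 3:0.5"));
        System.out.println(Files.readAllLines(out));
        Files.delete(out);
    }
}
```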


Modify the main method of the test class (SVMClassifierTest here; LibSVMTest in the code above) as follows:

// Test for svm_train and svm_predict (same modified wrappers as above)
// svm_train.main:
//     param:  String[], the command-line arguments of svm-train
//     return: String, the path of the model file
// svm_predict.main:
//     param:  String[], the command-line arguments of svm-predict, including the model file
//     return: Double, the accuracy of the SVM classification

// Scale the training data and store it; -s saves the per-feature min/max
// ("UCI-breast-cancer-range" is a file name chosen for illustration)
svm_scale.main(new String[]{"-s", "UCI-breast-cancer-range", "-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"});
// Scale the test data with the same min/max, restored with -r, and store it
svm_scale.main(new String[]{"-r", "UCI-breast-cancer-range", "-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"});

String[] scaleTrainArgs = {"UCI-breast-cancer-tra-scale"}; // scaled training file
String modelFile = svm_train.main(scaleTrainArgs);

String[] testArgs = {"UCI-breast-cancer-test-scale", modelFile, "UCI-breast-cancer-result"}; // scaled test file, model file, result file
Double accuracy = svm_predict.main(testArgs);
System.out.println("SVM Classification is done! The accuracy is " + accuracy);

// Test for cross validation
// String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
// modelFile = svm_train.main(crossValidationTrainArgs);
// System.out.print("Cross validation is done! The modelFile is " + modelFile);


Result:

*
optimization finished, #iter = 97
nu = 0.0711047842614367
obj = -78.46733678185721, rho = -0.9253740588830286
nSV = 99, nBSV = 83
Total nSV = 99
Accuracy = 89.74358974358975% (70/78) (classification)
SVM Classification is done! The accuracy is 0.8974358974358975


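The accuracy rises substantially, from about 69% to about 90%. When comparing the two runs, note that the second one reports 78 test samples instead of 39; this is consistent with the scaled test file having been written twice (for example, by running the program again while the writer appends to the existing file), in which case every test sample is counted twice.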

This completes the basic use of LIBSVM from Java, together with the scaling improvement.
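The cross-validation lines left commented out in both test programs call svm_train with -v 10, which makes svm-train run 10-fold cross validation and report a cross-validation accuracy instead of saving a model. In k-fold cross validation the training samples are split into k parts; each part is predicted once by a model trained on the other k-1 parts, and the accuracy is computed over all predictions. A minimal sketch of the splitting step (a hypothetical class; LIBSVM's own svm_cross_validation additionally keeps the class proportions balanced across folds for classification, which this plain version does not):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain (non-stratified) k-fold split of sample indices 0..n-1 after a seeded shuffle. */
final class KFold {
    static List<List<Integer>> folds(int n, int k, long seed) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("need 2 <= k <= n");
        }
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            idx.add(i);
        }
        Collections.shuffle(idx, new Random(seed));
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            // fold f takes positions [f*n/k, (f+1)*n/k) of the shuffled list
            folds.add(new ArrayList<>(idx.subList(f * n / k, (f + 1) * n / k)));
        }
        return folds;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = folds(20, 10, 42L);
        for (int f = 0; f < folds.size(); f++) {
            // train on all other folds, predict this one
            System.out.println("fold " + f + " (held out): " + folds.get(f));
        }
    }
}
```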

Citation:

Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2:27:1–27:27, 2011. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm.