K-均值(K-means)算法的 Java 实现
2010-12-24 22:33
302 查看
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
/**
 * K-means clustering of numeric samples loaded from a text file.
 *
 * <p>File format: the first line is {@code s;d;c} (number of samples, number of
 * dimensions per sample, number of cluster centers). Every following line holds
 * samples separated by {@code ;}, each sample being comma-separated values, for
 * example {@code 1,2;2,3;1,5}. There is no trailing {@code ;} and the samples may
 * span several lines. Trailing dimensions left out of a sample are treated as 0.
 */
public class KAverage {
    private int sampleCount = 0;
    private int dimensionCount = 0;
    private int centerCount = 0;
    /** One row per sample: {@code dimensionCount} values, then the assigned cluster index. */
    private double[][] sampleValues;
    /** Current cluster centers. */
    private double[][] centers;
    /** Centers of the previous iteration, compared with {@link #centers} to detect convergence. */
    private double[][] tmpCenters;
    private String dataFile = "";

    /**
     * Remembers the data file; it is only read when {@link #doCaculate()} runs.
     *
     * @param dataFile path of the sample file
     * @throws NumberInvalieException not thrown by this body; the clause is kept so callers
     *         written against the original signature still compile
     */
    public KAverage(String dataFile) throws NumberInvalieException {
        this.dataFile = dataFile;
    }

    /**
     * Loads the header and all samples from {@code fileName}.
     *
     * <p>The instance fields are only updated when the whole file was read successfully,
     * so a failed load never leaves counts set while the arrays are still {@code null}.
     *
     * @param fileName path of the sample file
     * @return 1 on success; -1 if the file is missing or the header does not have exactly
     *         three fields; -2 on an I/O error; -3 on any other problem (bad number,
     *         non-positive counts, sample/dimension count not matching the header)
     */
    private int initData(String fileName) {
        // try-with-resources closes the reader on every exit path.
        try (BufferedReader in = new BufferedReader(new FileReader(fileName))) {
            String line = in.readLine();
            if (line == null) {
                throw new NumberInvalieException("data file is empty.");
            }
            String[] params = line.split(";");
            if (params.length != 3) { // the header must hold exactly three parameters
                return -1;
            }
            int samples = Integer.parseInt(params[0].trim());
            int dimensions = Integer.parseInt(params[1].trim());
            int clusters = Integer.parseInt(params[2].trim());
            if (samples <= 0 || dimensions <= 0 || clusters <= 0) {
                throw new NumberInvalieException("input number <= 0.");
            }
            if (samples < clusters) {
                throw new NumberInvalieException("sample number < center number");
            }

            // Java zero-fills new arrays, so omitted trailing dimensions default to 0.
            // The extra last column holds the index of the cluster the sample belongs to.
            double[][] values = new double[samples][dimensions + 1];
            int row = 0;
            while ((line = in.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines
                }
                for (String sample : line.split(";")) {
                    if (row >= samples) {
                        throw new NumberInvalieException("more samples than declared.");
                    }
                    String[] fields = sample.split(",");
                    if (fields.length > dimensions) {
                        // Writing further would overwrite the cluster-index column.
                        throw new NumberInvalieException("more dimensions than declared.");
                    }
                    for (int k = 0; k < fields.length; k++) {
                        values[row][k] = Double.parseDouble(fields[k]);
                    }
                    row++;
                }
            }
            if (row != samples) {
                // Missing rows would otherwise be clustered as phantom all-zero samples.
                throw new NumberInvalieException("fewer samples than declared.");
            }

            this.sampleCount = samples;
            this.dimensionCount = dimensions;
            this.centerCount = clusters;
            this.sampleValues = values;
            this.centers = new double[clusters][dimensions];
            this.tmpCenters = new double[clusters][dimensions];
            return 1;
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return -1;
        } catch (IOException e) {
            e.printStackTrace();
            return -2;
        } catch (Exception e) {
            e.printStackTrace();
            return -3;
        }
    }

    /**
     * Returns the squared Euclidean distance between samples {@code s1} and {@code s2}.
     *
     * @throws NumberInvalieException if either index is outside {@code [0, sampleCount)}
     */
    private double getDistance(int s1, int s2) throws NumberInvalieException {
        if (s1 < 0 || s1 >= sampleCount || s2 < 0 || s2 >= sampleCount) {
            throw new NumberInvalieException("number out of bound.");
        }
        return getDistance(sampleValues[s1], sampleValues[s2]);
    }

    /**
     * Returns the squared Euclidean distance between two vectors, using their first
     * {@code dimensionCount} components. The square root is not taken; the value is
     * only used to compare distances, for which the squared form is sufficient.
     */
    private double getDistance(double s1[], double s2[]) {
        double distance = 0.0;
        for (int i = 0; i < dimensionCount; i++) {
            double diff = s1[i] - s2[i];
            distance += diff * diff;
        }
        return distance;
    }

    /**
     * Finds the center closest to sample {@code s}, stores its index in the sample's
     * last column and returns it. On ties the lowest center index wins.
     */
    private int getNearestCenter(int s) {
        int center = 0;
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < centerCount; i++) {
            double distance = getDistance(sampleValues[s], centers[i]);
            if (distance < minDistance) {
                minDistance = distance;
                center = i;
            }
        }
        sampleValues[s][dimensionCount] = center;
        return center;
    }

    /**
     * Recomputes every center as the mean of the samples currently assigned to it.
     * A cluster without samples keeps its previous center.
     */
    private void updateCenters() {
        for (int i = 0; i < centerCount; i++) {
            // Fresh accumulator per cluster; sharing one across clusters would add the
            // previous clusters' sums into this cluster's mean.
            double[] sum = new double[dimensionCount];
            int count = 0;
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    count++;
                    for (int k = 0; k < dimensionCount; k++) {
                        sum[k] += sampleValues[j][k];
                    }
                }
            }
            if (count == 0) {
                // 0.0 / 0 is NaN, and NaN != NaN would keep toBeContinued() true forever.
                continue;
            }
            for (int j = 0; j < dimensionCount; j++) {
                centers[i][j] = sum[j] / count;
            }
        }
    }

    /**
     * Tells whether another iteration is needed.
     *
     * @return {@code true} if any center differs from its value in the previous iteration
     */
    private boolean toBeContinued() {
        for (int i = 0; i < centerCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (tmpCenters[i][j] != centers[i][j]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Runs the clustering: loads the data, seeds the centers with the first
     * {@code centerCount} samples and iterates until the centers stop moving.
     * The grouping is printed after every iteration.
     *
     * @throws IllegalStateException if the data file could not be loaded
     */
    public void doCaculate() {
        int status = initData(dataFile);
        if (status != 1) {
            throw new IllegalStateException(
                    "failed to load data file '" + dataFile + "' (code " + status + ")");
        }
        for (int i = 0; i < centerCount; i++) {
            System.arraycopy(sampleValues[i], 0, centers[i], 0, dimensionCount);
        }
        // do-while: at least one assignment pass must run, even when the seed centers
        // happen to equal the zero-initialised tmpCenters.
        do {
            for (int i = 0; i < sampleCount; i++) {
                getNearestCenter(i);
            }
            for (int i = 0; i < centerCount; i++) {
                System.arraycopy(centers[i], 0, tmpCenters[i], 0, dimensionCount);
            }
            updateCenters();
            System.out
                    .println("******************************************************");
            showResultData();
        } while (toBeContinued());
    }

    /** Prints every sample (features only), one per line, comma-separated. */
    private void showSampleData() {
        for (int i = 0; i < sampleCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (j == 0) {
                    System.out.print(sampleValues[i][j]);
                } else {
                    System.out.print("," + sampleValues[i][j]);
                }
            }
            System.out.println();
        }
    }

    /**
     * Prints the samples grouped by cluster. Each sample line lists its features
     * followed by the cluster index stored in the last column.
     */
    private void showResultData() {
        for (int i = 0; i < centerCount; i++) {
            System.out.println("第" + (i + 1) + "个分组内容为:");
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    for (int k = 0; k <= dimensionCount; k++) {
                        if (k == 0) {
                            System.out.print(sampleValues[j][k]);
                        } else {
                            System.out.print("," + sampleValues[j][k]);
                        }
                    }
                    System.out.println();
                }
            }
        }
    }

    /**
     * Entry point. The data file defaults to a hard-coded path and can be overridden
     * by the first command-line argument.
     */
    public static void main(String[] args) {
        String fileName = "D://eclipsejava//K-Average//src//sample.txt";
        if (args.length > 0) {
            fileName = args[0];
        }
        try {
            KAverage ka = new KAverage(fileName);
            ka.doCaculate();
            System.out
                    .println("***************************<<result>>**************************");
            ka.showResultData();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
Java代码
/**
 * Checked exception reporting an invalid number, carrying a short textual reason.
 * Defining application-specific exceptions like this one keeps error reporting explicit.
 */
public class NumberInvalieException extends Exception {
    private static final long serialVersionUID = 1L;

    /** Reason for the failure; never {@code null} or empty. */
    private final String cause;

    /**
     * @param cause description of what was invalid; {@code null} or an empty string
     *        is replaced by {@code "unknow"}
     */
    public NumberInvalieException(String cause) {
        // Hand the reason to Exception as well so getMessage() is not null.
        super(normalize(cause));
        this.cause = normalize(cause);
    }

    /** Substitutes the default reason for a missing one. */
    private static String normalize(String cause) {
        return (cause == null || cause.isEmpty()) ? "unknow" : cause;
    }

    @Override
    public String toString() {
        return "Number Invalie!Cause by " + cause;
    }
}
测试数据
20;2;4
0,0;1,0;0,1;1,1;2,1;1,2;2,2;3,2;6,6;7,6
8,6;6,7;7,7;8,7;9,7;7,8;8,8;9,8;8,9;9,9
测试结果
***************************<<result>>**************************
第1个分组内容为:
0.0,0.0,0.0
1.0,0.0,0.0
0.0,1.0,0.0
1.0,1.0,0.0
2.0,1.0,0.0
1.0,2.0,0.0
2.0,2.0,0.0
3.0,2.0,0.0
第2个分组内容为:
6.0,6.0,1.0
7.0,6.0,1.0
8.0,6.0,1.0
6.0,7.0,1.0
7.0,7.0,1.0
8.0,7.0,1.0
9.0,7.0,1.0
7.0,8.0,1.0
8.0,8.0,1.0
9.0,8.0,1.0
8.0,9.0,1.0
9.0,9.0,1.0
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
/**
 * K-means clustering of numeric samples loaded from a text file.
 *
 * <p>File format: the first line is {@code s;d;c} (number of samples, number of
 * dimensions per sample, number of cluster centers). Every following line holds
 * samples separated by {@code ;}, each sample being comma-separated values, for
 * example {@code 1,2;2,3;1,5}. There is no trailing {@code ;} and the samples may
 * span several lines. Trailing dimensions left out of a sample are treated as 0.
 */
public class KAverage {
    private int sampleCount = 0;
    private int dimensionCount = 0;
    private int centerCount = 0;
    /** One row per sample: {@code dimensionCount} values, then the assigned cluster index. */
    private double[][] sampleValues;
    /** Current cluster centers. */
    private double[][] centers;
    /** Centers of the previous iteration, compared with {@link #centers} to detect convergence. */
    private double[][] tmpCenters;
    private String dataFile = "";

    /**
     * Remembers the data file; it is only read when {@link #doCaculate()} runs.
     *
     * @param dataFile path of the sample file
     * @throws NumberInvalieException not thrown by this body; the clause is kept so callers
     *         written against the original signature still compile
     */
    public KAverage(String dataFile) throws NumberInvalieException {
        this.dataFile = dataFile;
    }

    /**
     * Loads the header and all samples from {@code fileName}.
     *
     * <p>The instance fields are only updated when the whole file was read successfully,
     * so a failed load never leaves counts set while the arrays are still {@code null}.
     *
     * @param fileName path of the sample file
     * @return 1 on success; -1 if the file is missing or the header does not have exactly
     *         three fields; -2 on an I/O error; -3 on any other problem (bad number,
     *         non-positive counts, sample/dimension count not matching the header)
     */
    private int initData(String fileName) {
        // try-with-resources closes the reader on every exit path.
        try (BufferedReader in = new BufferedReader(new FileReader(fileName))) {
            String line = in.readLine();
            if (line == null) {
                throw new NumberInvalieException("data file is empty.");
            }
            String[] params = line.split(";");
            if (params.length != 3) { // the header must hold exactly three parameters
                return -1;
            }
            int samples = Integer.parseInt(params[0].trim());
            int dimensions = Integer.parseInt(params[1].trim());
            int clusters = Integer.parseInt(params[2].trim());
            if (samples <= 0 || dimensions <= 0 || clusters <= 0) {
                throw new NumberInvalieException("input number <= 0.");
            }
            if (samples < clusters) {
                throw new NumberInvalieException("sample number < center number");
            }

            // Java zero-fills new arrays, so omitted trailing dimensions default to 0.
            // The extra last column holds the index of the cluster the sample belongs to.
            double[][] values = new double[samples][dimensions + 1];
            int row = 0;
            while ((line = in.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines
                }
                for (String sample : line.split(";")) {
                    if (row >= samples) {
                        throw new NumberInvalieException("more samples than declared.");
                    }
                    String[] fields = sample.split(",");
                    if (fields.length > dimensions) {
                        // Writing further would overwrite the cluster-index column.
                        throw new NumberInvalieException("more dimensions than declared.");
                    }
                    for (int k = 0; k < fields.length; k++) {
                        values[row][k] = Double.parseDouble(fields[k]);
                    }
                    row++;
                }
            }
            if (row != samples) {
                // Missing rows would otherwise be clustered as phantom all-zero samples.
                throw new NumberInvalieException("fewer samples than declared.");
            }

            this.sampleCount = samples;
            this.dimensionCount = dimensions;
            this.centerCount = clusters;
            this.sampleValues = values;
            this.centers = new double[clusters][dimensions];
            this.tmpCenters = new double[clusters][dimensions];
            return 1;
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return -1;
        } catch (IOException e) {
            e.printStackTrace();
            return -2;
        } catch (Exception e) {
            e.printStackTrace();
            return -3;
        }
    }

    /**
     * Returns the squared Euclidean distance between samples {@code s1} and {@code s2}.
     *
     * @throws NumberInvalieException if either index is outside {@code [0, sampleCount)}
     */
    private double getDistance(int s1, int s2) throws NumberInvalieException {
        if (s1 < 0 || s1 >= sampleCount || s2 < 0 || s2 >= sampleCount) {
            throw new NumberInvalieException("number out of bound.");
        }
        return getDistance(sampleValues[s1], sampleValues[s2]);
    }

    /**
     * Returns the squared Euclidean distance between two vectors, using their first
     * {@code dimensionCount} components. The square root is not taken; the value is
     * only used to compare distances, for which the squared form is sufficient.
     */
    private double getDistance(double s1[], double s2[]) {
        double distance = 0.0;
        for (int i = 0; i < dimensionCount; i++) {
            double diff = s1[i] - s2[i];
            distance += diff * diff;
        }
        return distance;
    }

    /**
     * Finds the center closest to sample {@code s}, stores its index in the sample's
     * last column and returns it. On ties the lowest center index wins.
     */
    private int getNearestCenter(int s) {
        int center = 0;
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < centerCount; i++) {
            double distance = getDistance(sampleValues[s], centers[i]);
            if (distance < minDistance) {
                minDistance = distance;
                center = i;
            }
        }
        sampleValues[s][dimensionCount] = center;
        return center;
    }

    /**
     * Recomputes every center as the mean of the samples currently assigned to it.
     * A cluster without samples keeps its previous center.
     */
    private void updateCenters() {
        for (int i = 0; i < centerCount; i++) {
            // Fresh accumulator per cluster; sharing one across clusters would add the
            // previous clusters' sums into this cluster's mean.
            double[] sum = new double[dimensionCount];
            int count = 0;
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    count++;
                    for (int k = 0; k < dimensionCount; k++) {
                        sum[k] += sampleValues[j][k];
                    }
                }
            }
            if (count == 0) {
                // 0.0 / 0 is NaN, and NaN != NaN would keep toBeContinued() true forever.
                continue;
            }
            for (int j = 0; j < dimensionCount; j++) {
                centers[i][j] = sum[j] / count;
            }
        }
    }

    /**
     * Tells whether another iteration is needed.
     *
     * @return {@code true} if any center differs from its value in the previous iteration
     */
    private boolean toBeContinued() {
        for (int i = 0; i < centerCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (tmpCenters[i][j] != centers[i][j]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Runs the clustering: loads the data, seeds the centers with the first
     * {@code centerCount} samples and iterates until the centers stop moving.
     * The grouping is printed after every iteration.
     *
     * @throws IllegalStateException if the data file could not be loaded
     */
    public void doCaculate() {
        int status = initData(dataFile);
        if (status != 1) {
            throw new IllegalStateException(
                    "failed to load data file '" + dataFile + "' (code " + status + ")");
        }
        for (int i = 0; i < centerCount; i++) {
            System.arraycopy(sampleValues[i], 0, centers[i], 0, dimensionCount);
        }
        // do-while: at least one assignment pass must run, even when the seed centers
        // happen to equal the zero-initialised tmpCenters.
        do {
            for (int i = 0; i < sampleCount; i++) {
                getNearestCenter(i);
            }
            for (int i = 0; i < centerCount; i++) {
                System.arraycopy(centers[i], 0, tmpCenters[i], 0, dimensionCount);
            }
            updateCenters();
            System.out
                    .println("******************************************************");
            showResultData();
        } while (toBeContinued());
    }

    /** Prints every sample (features only), one per line, comma-separated. */
    private void showSampleData() {
        for (int i = 0; i < sampleCount; i++) {
            for (int j = 0; j < dimensionCount; j++) {
                if (j == 0) {
                    System.out.print(sampleValues[i][j]);
                } else {
                    System.out.print("," + sampleValues[i][j]);
                }
            }
            System.out.println();
        }
    }

    /**
     * Prints the samples grouped by cluster. Each sample line lists its features
     * followed by the cluster index stored in the last column.
     */
    private void showResultData() {
        for (int i = 0; i < centerCount; i++) {
            System.out.println("第" + (i + 1) + "个分组内容为:");
            for (int j = 0; j < sampleCount; j++) {
                if (sampleValues[j][dimensionCount] == i) {
                    for (int k = 0; k <= dimensionCount; k++) {
                        if (k == 0) {
                            System.out.print(sampleValues[j][k]);
                        } else {
                            System.out.print("," + sampleValues[j][k]);
                        }
                    }
                    System.out.println();
                }
            }
        }
    }

    /**
     * Entry point. The data file defaults to a hard-coded path and can be overridden
     * by the first command-line argument.
     */
    public static void main(String[] args) {
        String fileName = "D://eclipsejava//K-Average//src//sample.txt";
        if (args.length > 0) {
            fileName = args[0];
        }
        try {
            KAverage ka = new KAverage(fileName);
            ka.doCaculate();
            System.out
                    .println("***************************<<result>>**************************");
            ka.showResultData();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; public class KAverage { private int sampleCount = 0; private int dimensionCount = 0; private int centerCount = 0; private double[][] sampleValues; private double[][] centers; private double[][] tmpCenters; private String dataFile = ""; /** * 通过构造器传人数据文件 */ public KAverage(String dataFile) throws NumberInvalieException { this.dataFile = dataFile; } /** * 第一行为s;d;c含义分别为样例的数目,每个样例特征的维数,聚类中心个数 文件格式为d[,d]...;d[,d]... 如:1,2;2,3;1,5 * 每一维之间用,隔开,每个样例间用;隔开。结尾没有';' 可以有多行 */ private int initData(String fileName) { String line; String samplesValue[]; String dimensionsValue[] = new String[dimensionCount]; BufferedReader in; try { in = new BufferedReader(new FileReader(fileName)); } catch (FileNotFoundException e) { e.printStackTrace(); return -1; } /* * 预处理样本,允许后面几维为0时,不写入文件 */ for (int i = 0; i < sampleCount; i++) { for (int j = 0; j < dimensionCount; j++) { sampleValues[i][j] = 0; } } int i = 0; double tmpValue = 0.0; try { line = in.readLine(); String params[] = line.split(";"); if (params.length != 3) {// 必须为3个参数,否则错误 return -1; } /** * 获取参数 */ this.sampleCount = Integer.parseInt(params[0]); this.dimensionCount = Integer.parseInt(params[1]); this.centerCount = Integer.parseInt(params[2]); if (sampleCount <= 0 || dimensionCount <= 0 || centerCount <= 0) { throw new NumberInvalieException("input number <= 0."); } if (sampleCount < centerCount) { throw new NumberInvalieException( "sample number < center number"); } sampleValues = new double[sampleCount][dimensionCount + 1]; centers = new double[centerCount][dimensionCount]; tmpCenters = new double[centerCount][dimensionCount]; while ((line = in.readLine()) != null) { samplesValue = line.split(";"); for (int j = 0; j < samplesValue.length; j++) { dimensionsValue = samplesValue[j].split(","); for (int k = 0; k < dimensionsValue.length; k++) { tmpValue = Double.parseDouble(dimensionsValue[k]); 
sampleValues[i][k] = tmpValue; } i++; } } } catch (IOException e) { e.printStackTrace(); return -2; } catch (Exception e) { e.printStackTrace(); return -3; } return 1; } /** * 返回样本中第s1个和第s2个间的欧式距离 */ private double getDistance(int s1, int s2) throws NumberInvalieException { double distance = 0.0; if (s1 < 0 || s1 >= sampleCount || s2 < 0 || s2 >= sampleCount) { throw new NumberInvalieException("number out of bound."); } for (int i = 0; i < dimensionCount; i++) { distance += (sampleValues[s1][i] - sampleValues[s2][i]) * (sampleValues[s1][i] - sampleValues[s2][i]); } return distance; } /** * 返回给定两个向量间的欧式距离 */ private double getDistance(double s1[], double s2[]) { double distance = 0.0; for (int i = 0; i < dimensionCount; i++) { distance += (s1[i] - s2[i]) * (s1[i] - s2[i]); } return distance; } /** * 更新样本中第s个样本的最近中心 */ private int getNearestCenter(int s) { int center = 0; double minDistance = Double.MAX_VALUE; double distance = 0.0; for (int i = 0; i < centerCount; i++) { distance = getDistance(sampleValues[s], centers[i]); if (distance < minDistance) { minDistance = distance; center = i; } } sampleValues[s][dimensionCount] = center; return center; } /** * 更新所有中心 */ private void updateCenters() { double center[] = new double[dimensionCount]; for (int i = 0; i < dimensionCount; i++) { center[i] = 0; } int count = 0; for (int i = 0; i < centerCount; i++) { count = 0; for (int j = 0; j < sampleCount; j++) { if (sampleValues[j][dimensionCount] == i) { count++; for (int k = 0; k < dimensionCount; k++) { center[k] += sampleValues[j][k]; } } } for (int j = 0; j < dimensionCount; j++) { centers[i][j] = center[j] / count; } } } /** * 判断算法是否终止 */ private boolean toBeContinued() { for (int i = 0; i < centerCount; i++) { for (int j = 0; j < dimensionCount; j++) { if (tmpCenters[i][j] != centers[i][j]) { return true; } } } return false; } /** * 关键方法,调用其他方法,处理数据 */ public void doCaculate() { initData(dataFile); for (int i = 0; i < centerCount; i++) { for (int j = 0; j < 
dimensionCount; j++) { centers[i][j] = sampleValues[i][j]; } } for (int i = 0; i < centerCount; i++) { for (int j = 0; j < dimensionCount; j++) { tmpCenters[i][j] = 0; } } while (toBeContinued()) { for (int i = 0; i < sampleCount; i++) { getNearestCenter(i); } for (int i = 0; i < centerCount; i++) { for (int j = 0; j < dimensionCount; j++) { tmpCenters[i][j] = centers[i][j]; } } updateCenters(); System.out .println("******************************************************"); showResultData(); } } /* * 显示数据 */ private void showSampleData() { for (int i = 0; i < sampleCount; i++) { for (int j = 0; j < dimensionCount; j++) { if (j == 0) { System.out.print(sampleValues[i][j]); } else { System.out.print("," + sampleValues[i][j]); } } System.out.println(); } } /* * 分组显示结果 */ private void showResultData() { for (int i = 0; i < centerCount; i++) { System.out.println("第" + (i + 1) + "个分组内容为:"); for (int j = 0; j < sampleCount; j++) { if (sampleValues[j][dimensionCount] == i) { for (int k = 0; k <= dimensionCount; k++) { if (k == 0) { System.out.print(sampleValues[j][k]); } else { System.out.print("," + sampleValues[j][k]); } } System.out.println(); } } } } public static void main(String[] args) { /* *也可以通过命令行得到参数 */ String fileName = "D://eclipsejava//K-Average//src//sample.txt"; if(args.length > 0){ fileName = args[0]; } try { KAverage ka = new KAverage(fileName); ka.doCaculate(); System.out .println("***************************<<result>>**************************"); ka.showResultData(); } catch (Exception e) { e.printStackTrace(); } } }
Java代码
/**
 * Checked exception reporting an invalid number, carrying a short textual reason.
 * Defining application-specific exceptions like this one keeps error reporting explicit.
 */
public class NumberInvalieException extends Exception {
    private static final long serialVersionUID = 1L;

    /** Reason for the failure; never {@code null} or empty. */
    private final String cause;

    /**
     * @param cause description of what was invalid; {@code null} or an empty string
     *        is replaced by {@code "unknow"}
     */
    public NumberInvalieException(String cause) {
        // Hand the reason to Exception as well so getMessage() is not null.
        super(normalize(cause));
        this.cause = normalize(cause);
    }

    /** Substitutes the default reason for a missing one. */
    private static String normalize(String cause) {
        return (cause == null || cause.isEmpty()) ? "unknow" : cause;
    }

    @Override
    public String toString() {
        return "Number Invalie!Cause by " + cause;
    }
}
/**
 * Checked exception reporting an invalid number, carrying a short textual reason.
 * Defining application-specific exceptions like this one keeps error reporting explicit.
 */
public class NumberInvalieException extends Exception {
    private static final long serialVersionUID = 1L;

    /** Reason for the failure; never {@code null} or empty. */
    private final String cause;

    /**
     * @param cause description of what was invalid; {@code null} or an empty string
     *        is replaced by {@code "unknow"}
     */
    public NumberInvalieException(String cause) {
        // Hand the reason to Exception as well so getMessage() is not null.
        super(normalize(cause));
        this.cause = normalize(cause);
    }

    /** Substitutes the default reason for a missing one. */
    private static String normalize(String cause) {
        return (cause == null || cause.isEmpty()) ? "unknow" : cause;
    }

    @Override
    public String toString() {
        return "Number Invalie!Cause by " + cause;
    }
}
测试数据
20;2;4
0,0;1,0;0,1;1,1;2,1;1,2;2,2;3,2;6,6;7,6
8,6;6,7;7,7;8,7;9,7;7,8;8,8;9,8;8,9;9,9
测试结果
***************************<<result>>**************************
第1个分组内容为:
0.0,0.0,0.0
1.0,0.0,0.0
0.0,1.0,0.0
1.0,1.0,0.0
2.0,1.0,0.0
1.0,2.0,0.0
2.0,2.0,0.0
3.0,2.0,0.0
第2个分组内容为:
6.0,6.0,1.0
7.0,6.0,1.0
8.0,6.0,1.0
6.0,7.0,1.0
7.0,7.0,1.0
8.0,7.0,1.0
9.0,7.0,1.0
7.0,8.0,1.0
8.0,8.0,1.0
9.0,8.0,1.0
8.0,9.0,1.0
9.0,9.0,1.0
相关文章推荐
- 数据算法基于FPGA的图像处理(七)--Verilog实现均值滤波Strut2教程-java教程
- java:均值哈希实现图像内容相似度比较(图像视频相似度算法)
- 机器学习入门算法及其java实现-Kmeans(K均值)算法
- 扑克牌 洗牌算法 的java实现
- 算法代码实现之选出第k小元素、中位数、最小的k个元素(线性复杂度),Java实现
- MergeSort(归并排序)算法Java实现
- java 使用二叉堆实现 TopK 算法
- Newsgroup18828文本分类器、文本聚类器、关联分析频繁模式挖掘算法的Java实现工程下载及运行FAQ
- 几个比较经典的算法问题的java实现
- 【算法数据结构Java实现】折半查找
- java 高效率的排列组合算法(java实现)
- java 算法实现字符串的匹配
- 排序基础算法汇总-java实现
- bagging算法java实现(从N个样本中有放回地取N次)
- Java实现 字符串匹配 KMP 算法
- 算法与数据结构-二叉树 讲解与java代码实现
- Java DFA算法实现敏感词过滤
- 【密码学】RSA加解密原理及其Java实现算法
- 数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文档分类器的JAVA实现(下)
- java实现 阿拉伯数字转换为汉字数字 算法