您的位置:首页 > 编程语言 > Java开发

K-means算法(Java实现)

2016-04-26 17:36 471 查看
K-means算法是聚类中最简单的算法,在Weka、R、Matlab、SQL Server Business Intelligence Development Studio均可快速调用。

K-means算法原理



这里给出java实现的具体代码

public class Cluster {
private double[][] cluster;
private int size;//用来记录此簇中有多少条数据
public double[][] getCluster() {
return cluster;
}
public void setCluster(double[][] cluster) {
this.cluster = cluster;
}
public int getSize() {
return size;
}
public void setSize(int size) {
this.size = size;
}
public Cluster(int dataNum,int dimNum) {
super();
this.cluster = new double[dataNum][dimNum];
this.size=0;
}
}


//k均值算法实现
import java.io.BufferedReader;
import java.io.FileReader;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;

public class Kmeans {

/**
* 计算两个元组间的欧氏距离(欧几里得度量)
* @param tuple1 第一个向量
* @param tuple2 第二个向量
* @return
*/
private double getDistXY(double[] tuple1,double[] tuple2)
{
double sum=0;
for(int i=0;i<tuple1.length;i++)
{
sum+=Math.pow(tuple1[i]-tuple2[i],2);
}
return Math.sqrt(sum);
}

/**
* 根据质心,决定当前元组属于哪个簇
* @param centers  中心点,多个中心点
* @param tuple  向量
* @param K 簇数
* @return
*/
private int clusterOfDataSet(double[][] centers,double[] tuple,int k)
{
double dist=getDistXY(centers[0],tuple);
double temp;
int label=0;
for(int i=1;i<k;i++)
{
temp=getDistXY(centers[i],tuple);
if(temp<dist)
{
dist=temp;
label=i;
}
}
System.out.println(String.format("lable=%d", label));
return label;
}

/**
* 所有对象与形心的误差之和
* @param clusters 簇
* @param centers  中心点
* 第i个簇就与第i个中心点是相对应的,因为在方法clusterOfDataSet中的lable与i相互对应
* @return
*/
private double getVar(Cluster[] clusters,double[][] centers)
{
double var=0;
int k=clusters.length;//簇数
for(int i=0;i<k;i++)
{
double[][] d=clusters[i].getCluster();
//簇中的每条对象与形心的误差之和
for(int j=0;j<clusters[i].getSize();j++)
{
var+=getDistXY(d[j],centers[i]);
}
}
return var;
}

/**
* 得到簇里面的中心点
* @param cluster
* @return
*/
private double[] getCenter(Cluster cluster,int dimNum)
{
int num=cluster.getSize();
double[] temp=new double[dimNum];
double[][] d=cluster.getCluster();
for(int i=0;i<num;i++)
{
for(int j=0;j<dimNum;j++)
{
temp[j]+=d[i][j];
}
}
for(int j=0;j<dimNum;j++)
{
temp[j]/=num;
}
return temp;
}

/**
* K-means算法
* @param dataSet数据集
* @param dataNum行数
* @param dimNum列数
* @param k将要划分的簇的数量
* @param seed种子
*/
private void printKmeans(double[][] dataSet,int dataNum,int dimNum,int k,int seed)
{
printTwoDimensionalArray(dataSet);
//K个簇初始化
Cluster[] clusters=new Cluster[k];
for(int i=0;i<clusters.length;i++)
{
clusters[i]=new Cluster(dataNum,dimNum);
}
int i=0;
//一开始随机选取K条记录的值作为K个簇的质心(均值)
Random random=new Random(seed);
Set<Integer> set=new HashSet<Integer>();
while(set.size()<=k)
{
int select=random.nextInt(dataNum);
set.add(select);
if(set.size()==k) break;
}
List<Integer> list=new ArrayList<Integer>(set);
double[][] centers=new double[k][dimNum];
for(i=0;i<k;i++)
{
System.out.println("被选中的数据下标是:"+list.get(i));
for(int j=0;j<dimNum;j++)
{
centers[i][j]=dataSet[list.get(i)][j];
}
}

System.out.println("中心点数据centers:");
printTwoDimensionalArray(centers);

int lable=0;
//根据默认的质心给簇赋值
for(i=0;i<dataNum;i++)
{
lable=clusterOfDataSet(centers,dataSet[i],k);//lable的取值为0,1,...k-1
//将dataSet第i条数据放到 第lable个簇中,这里的第lable个簇与第lable个中心点相互对应
for(int column=0;column<dimNum;column++)
{
Cluster tempCluster=clusters[lable];
double[][] temp=tempCluster.getCluster();
int size=tempCluster.getSize();
temp[size][column]=dataSet[i][column];
}
clusters[lable].setSize(clusters[lable].getSize()+1);
}

double oldVar=-1;
double newVar=getVar(clusters,centers);
System.out.println("初始的整体误差平方和为:"+newVar);
int t=0;
while(Math.abs(newVar-oldVar)>=1)//当新旧函数值相差不到1即准则函数值不发生明显变化时,算法终止
{
System.out.println("第"+(++t)+"次迭代开始:");
for(i=0;i<k;i++)
{
centers[i]=getCenter(clusters[i],dimNum);
}
System.out.println("中心点数据centers:");
printTwoDimensionalArray(centers);
oldVar=newVar;
newVar=getVar(clusters,centers);
//清空每个簇
for(i=0;i<clusters.length;i++)
{
clusters[i].setSize(0);
}
//根据新的中心店获得新的簇
for(i=0;i<dataNum;i++)
{
lable=clusterOfDataSet(centers,dataSet[i],k);//lable的取值为0,1,...k-1
//将dataSet第i条数据放到 第lable个簇中
for(int column=0;column<dimNum;column++)
{
Cluster tempCluster=clusters[lable];
double[][] temp=tempCluster.getCluster();
int size=tempCluster.getSize();
temp[size][column]=dataSet[i][column];
}
clusters[lable].setSize(clusters[lable].getSize()+1);
}
System.out.println("此次迭代之后的整体误差平方和为:"+newVar);
}
System.out.println("The result is:\n");
for(i=0;i<k;i++)
{
System.out.println(String.format("第%d个簇中是", i));
Cluster c=clusters[i];
printClustery(c);
}
}

public static void main(String[] args) throws Exception
{
//获得当前时间
Date date=new Date();

Scanner sc=new Scanner(System.in);
System.out.println("请输入文件的绝对路径");
String fileName=sc.nextLine();
System.out.println("请输入样本的维数");//4
int dimNum=sc.nextInt();
System.out.println("请输入样本的数量(行数)");//150
int dataNum=sc.nextInt();
System.out.println("请输入最终形成的簇数");//3
int clusterNum=sc.nextInt();
Kmeans k=new Kmeans();
double[][] dataSet=k.readData(fileName,dataNum,dimNum);
//k.printTwoDimensionalArray(k.readData(fileName, dataNum, dimNum));
date=new Date();
long time=date.getTime();
SimpleDateFormat sdf=new SimpleDateFormat("yyyy-MM-dd hh:mm:ss-SSS");
System.out.println(sdf.format(time));

k.printKmeans(dataSet,dataNum,dimNum,clusterNum,100);

date=new Date();
time=date.getTime();
System.out.println(sdf.format(time));
}

/**
* 从文件中读取数据,并且返回一个二维数组
* @param fileName文件的全名
* @param dimNum将要保存的数据的列数
* @param dataNum文件中有效数据的行数
* @return
* @throws Exception
*/
private double[][] readData(String fileName,int dataNum,int dimNum) throws Exception
{
double[][] data=new double[dataNum][dimNum];
FileReader fr=null;
BufferedReader br=null;
try
{
fr=new FileReader(fileName);
br=new BufferedReader(fr);
//存放数据的临时变量
String lineData=null;
String[] splitData=null;
int line=0;
//从文件的头部开始读取,到@data处停止
while(br.ready())
{
lineData=br.readLine();
if(lineData.toUpperCase().equals("@DATA"))
{
break;
}
}
while(br.ready())
{
lineData=br.readLine();
splitData=lineData.split(",");
if(splitData.length>1)
{
for(int i=0;i<splitData.length;i++)
{
if(IsDouble(splitData[i]))
{
data[line][i]=Double.valueOf(splitData[i]);
System.out.println(String.format("data[%d][%d]=%f", line,i,data[line][i]));
}
}
line++;
}
}
return data;
}
catch(Exception ex)
{
throw new Exception(ex);
}
finally
{
fr.close();
br.close();
}
}

/**
* 判断一个字符串可否转为double类型的数据
* @param str
* @return
*/
private boolean IsDouble(String str)
{
try
{
double d=Double.valueOf(str);
return true;
}
catch(NumberFormatException ex)
{
//System.out.println(ex.getMessage()+"\n"+ex.toString());
return false;
}
}

/**
* 打印输出二维数组
* @param array
*/
private void printTwoDimensionalArray(double[][] array)
{
for(int i=0;i<array.length;i++)
{
for(int j=0;j<array[0].length;j++)
{
System.out.print(String.format("%6f   ",array[i][j]));
}
System.out.println();
}
}

/**
* 打印输出簇中的元素
* @param cluster
*/
private void printClustery(Cluster cluster)
{
for(int i=0;i<cluster.getSize();i++)
{
double[][] array=cluster.getCluster();
for(int j=0;j<array[0].length;j++)
{
System.out.print(String.format("%6f   ",array[i][j]));
}
System.out.println();
}
}

}


测试数据,这里给出iris数据的部分数据,其详细数据,可以在UCI机器学习的数据集中查找并下载,网址(http://archive.ics.uci.edu/ml/)

@RELATION iris

@ATTRIBUTE sepallength  REAL
@ATTRIBUTE sepalwidth   REAL
@ATTRIBUTE petallength  REAL
@ATTRIBUTE petalwidth   REAL
@ATTRIBUTE class    {Iris-setosa,Iris-versicolor,Iris-virginica}

@DATA
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa


备注:根据/article/1851414.html 上面的C++代码改编的
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: