K-means算法(Java实现)
2016-04-26 17:36
471 查看
K-means算法是聚类中最简单的算法,在Weka、R、Matlab、SQL Server Business Intelligence Development Studio均可快速调用。
K-means算法原理
这里给出java实现的具体代码
测试数据,这里给出iris数据的部分数据,其详细数据,可以在UCI机器学习的数据集中查找并下载,网址(http://archive.ics.uci.edu/ml/)
备注:根据/article/1851414.html 上面的C++代码改编的
K-means算法原理
这里给出java实现的具体代码
public class Cluster { private double[][] cluster; private int size;//用来记录此簇中有多少条数据 public double[][] getCluster() { return cluster; } public void setCluster(double[][] cluster) { this.cluster = cluster; } public int getSize() { return size; } public void setSize(int size) { this.size = size; } public Cluster(int dataNum,int dimNum) { super(); this.cluster = new double[dataNum][dimNum]; this.size=0; } }
//k均值算法实现 import java.io.BufferedReader; import java.io.FileReader; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.Scanner; import java.util.Set; public class Kmeans { /** * 计算两个元组间的欧氏距离(欧几里得度量) * @param tuple1 第一个向量 * @param tuple2 第二个向量 * @return */ private double getDistXY(double[] tuple1,double[] tuple2) { double sum=0; for(int i=0;i<tuple1.length;i++) { sum+=Math.pow(tuple1[i]-tuple2[i],2); } return Math.sqrt(sum); } /** * 根据质心,决定当前元组属于哪个簇 * @param centers 中心点,多个中心点 * @param tuple 向量 * @param K 簇数 * @return */ private int clusterOfDataSet(double[][] centers,double[] tuple,int k) { double dist=getDistXY(centers[0],tuple); double temp; int label=0; for(int i=1;i<k;i++) { temp=getDistXY(centers[i],tuple); if(temp<dist) { dist=temp; label=i; } } System.out.println(String.format("lable=%d", label)); return label; } /** * 所有对象与形心的误差之和 * @param clusters 簇 * @param centers 中心点 * 第i个簇就与第i个中心点是相对应的,因为在方法clusterOfDataSet中的lable与i相互对应 * @return */ private double getVar(Cluster[] clusters,double[][] centers) { double var=0; int k=clusters.length;//簇数 for(int i=0;i<k;i++) { double[][] d=clusters[i].getCluster(); //簇中的每条对象与形心的误差之和 for(int j=0;j<clusters[i].getSize();j++) { var+=getDistXY(d[j],centers[i]); } } return var; } /** * 得到簇里面的中心点 * @param cluster * @return */ private double[] getCenter(Cluster cluster,int dimNum) { int num=cluster.getSize(); double[] temp=new double[dimNum]; double[][] d=cluster.getCluster(); for(int i=0;i<num;i++) { for(int j=0;j<dimNum;j++) { temp[j]+=d[i][j]; } } for(int j=0;j<dimNum;j++) { temp[j]/=num; } return temp; } /** * K-means算法 * @param dataSet数据集 * @param dataNum行数 * @param dimNum列数 * @param k将要划分的簇的数量 * @param seed种子 */ private void printKmeans(double[][] dataSet,int dataNum,int dimNum,int k,int seed) { printTwoDimensionalArray(dataSet); //K个簇初始化 Cluster[] clusters=new Cluster[k]; for(int i=0;i<clusters.length;i++) { clusters[i]=new Cluster(dataNum,dimNum); } int i=0; //一开始随机选取K条记录的值作为K个簇的质心(均值) Random random=new Random(seed); Set<Integer> set=new HashSet<Integer>(); while(set.size()<=k) { int select=random.nextInt(dataNum); set.add(select); if(set.size()==k) break; } List<Integer> list=new ArrayList<Integer>(set); double[][] centers=new double[k][dimNum]; for(i=0;i<k;i++) { System.out.println("被选中的数据下标是:"+list.get(i)); for(int j=0;j<dimNum;j++) { centers[i][j]=dataSet[list.get(i)][j]; } } System.out.println("中心点数据centers:"); printTwoDimensionalArray(centers); int lable=0; //根据默认的质心给簇赋值 for(i=0;i<dataNum;i++) { lable=clusterOfDataSet(centers,dataSet[i],k);//lable的取值为0,1,...k-1 //将dataSet第i条数据放到 第lable个簇中,这里的第lable个簇与第lable个中心点相互对应 for(int column=0;column<dimNum;column++) { Cluster tempCluster=clusters[lable]; double[][] temp=tempCluster.getCluster(); int size=tempCluster.getSize(); temp[size][column]=dataSet[i][column]; } clusters[lable].setSize(clusters[lable].getSize()+1); } double oldVar=-1; double newVar=getVar(clusters,centers); System.out.println("初始的整体误差平方和为:"+newVar); int t=0; while(Math.abs(newVar-oldVar)>=1)//当新旧函数值相差不到1即准则函数值不发生明显变化时,算法终止 { System.out.println("第"+(++t)+"次迭代开始:"); for(i=0;i<k;i++) { centers[i]=getCenter(clusters[i],dimNum); } System.out.println("中心点数据centers:"); printTwoDimensionalArray(centers); oldVar=newVar; newVar=getVar(clusters,centers); //清空每个簇 for(i=0;i<clusters.length;i++) { clusters[i].setSize(0); } //根据新的中心店获得新的簇 for(i=0;i<dataNum;i++) { lable=clusterOfDataSet(centers,dataSet[i],k);//lable的取值为0,1,...k-1 //将dataSet第i条数据放到 第lable个簇中 for(int column=0;column<dimNum;column++) { Cluster tempCluster=clusters[lable]; double[][] temp=tempCluster.getCluster(); int size=tempCluster.getSize(); temp[size][column]=dataSet[i][column]; } clusters[lable].setSize(clusters[lable].getSize()+1); } System.out.println("此次迭代之后的整体误差平方和为:"+newVar); } System.out.println("The result is:\n"); for(i=0;i<k;i++) { System.out.println(String.format("第%d个簇中是", i)); Cluster c=clusters[i]; printClustery(c); } } public static void main(String[] args) throws Exception { //获得当前时间 Date date=new Date(); Scanner sc=new Scanner(System.in); System.out.println("请输入文件的绝对路径"); String fileName=sc.nextLine(); System.out.println("请输入样本的维数");//4 int dimNum=sc.nextInt(); System.out.println("请输入样本的数量(行数)");//150 int dataNum=sc.nextInt(); System.out.println("请输入最终形成的簇数");//3 int clusterNum=sc.nextInt(); Kmeans k=new Kmeans(); double[][] dataSet=k.readData(fileName,dataNum,dimNum); //k.printTwoDimensionalArray(k.readData(fileName, dataNum, dimNum)); date=new Date(); long time=date.getTime(); SimpleDateFormat sdf=new SimpleDateFormat("yyyy-MM-dd hh:mm:ss-SSS"); System.out.println(sdf.format(time)); k.printKmeans(dataSet,dataNum,dimNum,clusterNum,100); date=new Date(); time=date.getTime(); System.out.println(sdf.format(time)); } /** * 从文件中读取数据,并且返回一个二维数组 * @param fileName文件的全名 * @param dimNum将要保存的数据的列数 * @param dataNum文件中有效数据的行数 * @return * @throws Exception */ private double[][] readData(String fileName,int dataNum,int dimNum) throws Exception { double[][] data=new double[dataNum][dimNum]; FileReader fr=null; BufferedReader br=null; try { fr=new FileReader(fileName); br=new BufferedReader(fr); //存放数据的临时变量 String lineData=null; String[] splitData=null; int line=0; //从文件的头部开始读取,到@data处停止 while(br.ready()) { lineData=br.readLine(); if(lineData.toUpperCase().equals("@DATA")) { break; } } while(br.ready()) { lineData=br.readLine(); splitData=lineData.split(","); if(splitData.length>1) { for(int i=0;i<splitData.length;i++) { if(IsDouble(splitData[i])) { data[line][i]=Double.valueOf(splitData[i]); System.out.println(String.format("data[%d][%d]=%f", line,i,data[line][i])); } } line++; } } return data; } catch(Exception ex) { throw new Exception(ex); } finally { fr.close(); br.close(); } } /** * 判断一个字符串可否转为double类型的数据 * @param str * @return */ private boolean IsDouble(String str) { try { double d=Double.valueOf(str); return true; } catch(NumberFormatException ex) { //System.out.println(ex.getMessage()+"\n"+ex.toString()); return false; } } /** * 打印输出二维数组 * @param array */ private void printTwoDimensionalArray(double[][] array) { for(int i=0;i<array.length;i++) { for(int j=0;j<array[0].length;j++) { System.out.print(String.format("%6f ",array[i][j])); } System.out.println(); } } /** * 打印输出簇中的元素 * @param cluster */ private void printClustery(Cluster cluster) { for(int i=0;i<cluster.getSize();i++) { double[][] array=cluster.getCluster(); for(int j=0;j<array[0].length;j++) { System.out.print(String.format("%6f ",array[i][j])); } System.out.println(); } } }
测试数据,这里给出iris数据的部分数据,其详细数据,可以在UCI机器学习的数据集中查找并下载,网址(http://archive.ics.uci.edu/ml/)
@RELATION iris @ATTRIBUTE sepallength REAL @ATTRIBUTE sepalwidth REAL @ATTRIBUTE petallength REAL @ATTRIBUTE petalwidth REAL @ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica} @DATA 5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setosa 4.6,3.1,1.5,0.2,Iris-setosa 5.0,3.6,1.4,0.2,Iris-setosa 5.4,3.9,1.7,0.4,Iris-setosa 4.6,3.4,1.4,0.3,Iris-setosa 5.0,3.4,1.5,0.2,Iris-setosa
备注:根据/article/1851414.html 上面的C++代码改编的
相关文章推荐
- Java集合---ConcurrentHashMap原理分析
- Spring自动注入properties文件
- 修改eclipse自动生成的comments中的author名字
- java如何调用服务端的WSDL接口
- java如何调用服务端的WSDL接口
- Java简繁转换ZHConverter
- notepad++编辑的json文件copy到myEclipse后中文乱码
- springmvc的请求过滤器(session过期)
- java事物隔离性和传播
- java.io.File.deleteOnExit()-生成临时文件,删除临时文件
- java 时区问题 SimpleDateFormat 时区大全
- Java:使用synchronized和Lock对象获取对象锁
- Spring AOP 实现业务日志记录
- Strut2 Spring hibernate的优缺点
- [java] 多态实现的JVM调用过程
- RxJava学习(四)
- java的注入Deprecated
- Java Socket
- java基础数据遍历(4)删除数组中重复数字
- [Spring入门点滴]利用构造函数和setter注入