您的位置:首页 > 编程语言 > Java开发

K-means算法原理以及java实现

2017-06-17 23:05 309 查看
我做了一个小例子,将k-means算法用在我最近做的一个系统中。以下介绍k-means算法。

(1)k-means算法的简介

本系统使用k-means算法来计算一维数据的聚集程度,实现圈子的划分,这里的一维数据是所有的点,用A、B、C、D来表示每一个点,任意两个点之间的最短距离的计算方法已经封装成为接口,直接调用即可。K-Means算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把待分类样本点分到各个簇。然后重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心不再发生改变为止,算法结束。


(2)k-means算法的设计思想

如图所示,图中共有5个点,分别是A、B、C、D、E,灰色的点是种子点,也就是用来找点群的点,因为有两个种子点,所以k=2.



(3)k-means算法的具体步骤

如上图所示,随机在图中取K(这里K=2)个种子点。
①然后对图中的所有点求到这K个种子点的距离,假如点Pi离种子点Si最近,那么Pi属于Si点群。(上图中,可以看到A,B属于上面的种子点,C,D,E属于下面中部的种子点)。
②接下来,需要移动种子点到属于他的“点群”的中心。(见图上的第三步)
③然后重复第2)和第3)步,直到,种子点没有移动(可以看到上图中的第四步上面的种子点聚合了A,B,C,下面的种子点聚合了D,E)。


我做的例子中对于k-means的应用具体步骤:

①随机在数据库中取K个种子点,一般情况下,取数据库所存储数据的前n个数据作为初始种子点。
②然后对数据库中的所有点求到这K个种子点的距离,假如点Pi离种子点
Si最近,那么Pi属于Si点群。这样就将所有点划分成了k个圈子。
③接下来,要移动种子点到属于他的“点群”的中心,移动思路是使得点群中所有点到中心点的距离之和最短,我们将这个中心点称点群的中心。求最短距离之和的过程中,本系统使用了Dijkstra算法,将求某个点到其余所有点的最短距离之和的过程写成了接口,返回更新后的种子点和最短距离之和。
④然后重复第(2)和第(3)步,直到,种子点不再发生变化。则分圈结束。


(4)k-means算法的流程图



(5)k-means算法的结果测试:

本系统设定将所有的点分成三个圈子,程序总共循环分圈了两次,进行两次分圈之后,中心点不再发生变化,于是程序停止运行。



上图是第一次分圈的过程,将A、B、C 作为初始种子点,计算其余各点距离这三个种子点的距离,距离A、B、C哪个点最近,就将它归为这个点群中。下图是第一次分圈结束后的结果,最后的分圈结果包含三行,第一行是点群A,其中D、E、F都属于A点所在圈子,第二行是点群B,其中B、H属于B点所在圈子,第三行是点群C,其中G点属于C点所在的圈子。



经过一次分圈之后,本系统的中心种子点从A,B,C三点变成了A,H,C,如下图,是系统将A、H、C 作为更新后的种子点,计算其余各点距离这三个种子点的距离,距离A、H、C哪个点最近,就将它归为这个点群中。最后的分圈结果包含三行,第一行是点群A,其中B、D、E、F都属于A点所在圈子,第二行是点群H,其中G属于H点所在圈子,第三行是点群C,只包含C点本身。

第二次分圈后,中心点不再改变,程序终止。



第二次分圈结束后结果:



这也是最终分圈的结果。

(5) k-means算法的java实现:

public class Basickmeanstest {

//给定需要分圈的初始节点
static String[] p = { "A", "B", "C", "D", "F","E","G", "H" };
public static void main(String[] args) throws Exception {
//k表示分圈的个数,这里需要分为3个圈,k根据数据量大小自定义。
int  k = 3;

//二维数组g用来存储分圈之后的结果
//二维数组,每一行代表一个圈子,单独一行中的所有点表示这个圈子之内的元素。
String[][] g;
g = cluster(p, k);

//将分圈结果输出
for (int i = 0; i < g.length; i++) {
for (int j = 0; j < g[i].length; j++) {
System.out.print(g[i][j]+'\t');
}
System.out.println();
}
}

public static String[][] cluster(String[] p, int k) throws Exception {

//c存放原始的中心点
String[] c = new String[k];

//nc存放更新之后的中心点
String[] nc = new String[k];
String[][] g;
for (int i = 0; i < k; i++) {
c[i] = p[i];
}
for(int i=0; i<c.length; i++ ){
}
while (true) {
g = group(p, c);
for(int i=0;i<g.length;i++){
for(int j=0;j<g[i].length;j++)
System.out.print(g[i][j]+' ');
System.out.println();
}
System.out.println("----------------------");
for (int i = 0; i < g.length; i++) {
nc[i] = center(g[i]);

}
//当更新后的中心点和初始中心点不同的时候,更新中心点。
if (!equal(nc, c)) {
c = nc;
nc = new String[k];
} else {
break;
}
}
return g;
}

public static Object getMinValue(HashMap<String, Integer> map) {
if (map == null)
return null;
Collection<Integer> c = map.values();
Object[] obj = c.toArray();
Arrays.sort(obj);
return obj[0];
}

//这里是我更新中心点的依据,调用的是我写好的接口,依据是在有向图中,若某个点到其余所有点的距离之和最短,则将这个点作为新的中心点。
public static String center(String[] p) throws Exception {
return Main.getShort(p);

}

//这是分圈的核心方法
public static String[][] group(String[] p, String[] c) throws Exception {
int[] gi = new int[p.length];
for (int i = 0; i < p.length; i++) {
// 存放距离
int[] d = new int[c.length];
// 计算到每个聚类中心的距离
for (int j = 0; j < c.length; j++) {
d[j] = distance(p[i], c[j]);
}
// 找出最小距离
int ci = min(d);
System.out.println("较小的值为"+d[ci]);
System.out.println("ci="+ci);
// 标记属于哪一组
//ci表示的是属于哪一个组,在本例中,k=3,分为3个全,则ci的可能取值为0,1,2.当ci=0的时候,表示属于第一个圈子,即会被分到结果g的第一行,以此类推。
gi[i] = ci;
}
String[][] g = new String[c.length][];
// 遍历每个聚类中心,分组
for (int i = 0; i < c.length; i++) {
// 中间变量,记录聚类后每一组的大小
int s = 0;
// 计算每一组的长度
for (int j = 0; j < gi.length; j++)
if (gi[j] == i)
s++;
// 存储每一组的成员
g[i] = new String[s];
s = 0;
// 根据分组标记将各元素归位
for (int j = 0; j < gi.length; j++)
if (gi[j] == i) {
g[i][s] = p[j];
s++;
}
}
return g;
}

//这是求两个节点之间的距离的方法
//此方法是我已经写好的接口,距离直接从数据库中读出来即可。
public static int distance(String x, String y) throws Exception {
DijstTestDao dijstra = DijstraFactory.getDijistra();
int len;
if(x.equals(y)) {
len= 0;
} else {
Main.getShort(p);
len = Main.findShortInstance(x, y);

}
System.out.println("第一个点"+x+"和第二个点"+y+"之间距离为:"+len);
return len;
}

public static int min(int[] p) {
int i = 0;
int m = p[0];
for (int j = 1; j < p.length; j++) {
if (p[j] < m) {
i = j;
m = p[j];
}
}
return i;
}

public static boolean equal(String[] a, String[] b) {
if (a.length != b.length)
return false;
else {
for (int i = 0; i < a.length; i++) {
if (a[i] != b[i])
return false;
}
}
return true;
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: