您的位置:首页 > 编程语言 > Java开发

聚类算法之kmeans算法java版本

2012-04-23 09:41 1896 查看

      聚类的意思很明确,物以类聚,把类似的事物放在一起。
      聚类算法是web智能中很重要的一步,可运用在社交,新闻,电商等各种应用中,我打算专门开个分类讲解聚类各种算法的java版实现。
     首先介绍kmeans算法。
     kmeans算法的速度很快,性能良好,几乎是应用最广泛的,它需要先指定聚类的个数k,然后根据k值来自动分出k个类别集合。
     举个例子,某某教练在得到全队的数据后,想把这些球员自动分成不同的组别,你得问教练需要分成几个组,他回答你k个,ok可以开始了,在解决这个问题之前有必要详细了解自己需要达到的目的:根据教练给出的k值,呈现出k个组,每个组的队员是相似的。
     首先,我们创建球员类。 

 

package kmeans;

/**
* 球员
*
* @author 阿飞哥
*
*/
public class Player {

private int id;
private String name;

private int age;

/* 得分 */
@KmeanField
private double goal;

/* 助攻 */
//@KmeanField
private double assists;

/* 篮板 */
//@KmeanField
private double backboard;

/* 抢断 */
//@KmeanField
private double steals;

public int getId() {
return id;
}

public void setId(int id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public int getAge() {
return age;
}

public void setAge(int age) {
this.age = age;
}

public double getGoal() {
return goal;
}

public void setGoal(double goal) {
this.goal = goal;
}

public double getAssists() {
return assists;
}

public void setAssists(double assists) {
this.assists = assists;
}

public double getBackboard() {
return backboard;
}

public void setBackboard(double backboard) {
this.backboard = backboard;
}

public double getSteals() {
return steals;
}

public void setSteals(double steals) {
this.steals = steals;
}

}

 

        

   @KmeanField这个注解是自定义的,用来标示这个属性是否是算法需要的维度。
代码如下 

package kmeans;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* 在对象的属性上标注此注释,
* 表示纳入kmeans算法,仅支持数值类属性
* @author 阿飞哥
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface KmeanField {
}

 

接下来看看最核心的kmeans算法,具体实现过程如下:
1,初始化k个聚类中心
2,计算出每个对象跟这k个中心的距离(相似度计算,这个下面会提到),假如x这个对象跟y这个中心的距离最小(相似度最大),那么x属于y这个中心。这一步就可以得到初步的k个聚类
3,在第二步得到的每个聚类分别计算出新的聚类中心,和旧的中心比对,假如不相同,则继续第2步,直到新旧两个中心相同,说明聚类不可变,已经成功

实现代码如下:

package kmeans;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
*
* @author 阿飞哥
*
*/
public class Kmeans<T> {

/**
* 所有数据列表
*/
private List<T> players = new ArrayList<T>();

/**
* 数据类别
*/
private Class<T> classT;

/**
* 初始化列表
*/
private List<T> initPlayers;

/**
* 需要纳入kmeans算法的属性名称
*/
private List<String> fieldNames = new ArrayList<String>();

/**
* 分类数
*/
private int k = 1;

public Kmeans() {

}

/**
* 初始化列表
*
* @param list
* @param k
*/
public Kmeans(List<T> list, int k) {
this.players = list;
this.k = k;
T t = list.get(0);
this.classT = (Class<T>) t.getClass();
Field[] fields = this.classT.getDeclaredFields();
for (int i = 0; i < fields.length; i++) {
Annotation kmeansAnnotation = fields[i]
.getAnnotation(KmeanField.class);
if (kmeansAnnotation != null) {
fieldNames.add(fields[i].getName());
}

}

initPlayers = new ArrayList<T>();
for (int i = 0; i < k; i++) {
initPlayers.add(players.get(i));
}
}

public List<T>[] comput() {
List<T>[] results = new ArrayList[k];

boolean centerchange = true;
while (centerchange) {
centerchange = false;
for (int i = 0; i < k; i++) {
results[i] = new ArrayList<T>();
}
for (int i = 0; i < players.size(); i++) {
T p = players.get(i);
double[] dists = new double[k];
for (int j = 0; j < initPlayers.size(); j++) {
T initP = initPlayers.get(j);
/* 计算距离 */
double dist = distance(initP, p);
dists[j] = dist;
}

int dist_index = computOrder(dists);
results[dist_index].add(p);
}

for (int i = 0; i < k; i++) {
T player_new = findNewCenter(results[i]);
T player_old = initPlayers.get(i);
if (!IsPlayerEqual(player_new, player_old)) {
centerchange = true;
initPlayers.set(i, player_new);
}

}

}

return results;
}

/**
* 比较是否两个对象是否属性一致
*
* @param p1
* @param p2
* @return
*/
public boolean IsPlayerEqual(T p1, T p2) {
if (p1 == p2) {
return true;
}
if (p1 == null || p2 == null) {
return false;
}

boolean flag = true;
try {
for (int i = 0; i < fieldNames.size(); i++) {
String fieldName=fieldNames.get(i);
String getName = "get"
+ fieldName.substring(0, 1).toUpperCase()
+ fieldName.substring(1);
Object value1 = invokeMethod(p1,getName,null);
Object value2 = invokeMethod(p2,getName,null);
if (!value1.equals(value2)) {
flag = false;
break;
}
}
} catch (Exception e) {
e.printStackTrace();
flag = false;
}

return flag;
}

/**
* 得到新聚类中心对象
*
* @param ps
* @return
*/
public T findNewCenter(List<T> ps) {
try {
T t = classT.newInstance();
if (ps == null || ps.size() == 0) {
return t;
}

double[] ds = new double[fieldNames.size()];
for (T vo : ps) {
for (int i = 0; i < fieldNames.size(); i++) {
String fieldName=fieldNames.get(i);
String getName = "get"
+ fieldName.substring(0, 1).toUpperCase()
+ fieldName.substring(1);
Object obj=invokeMethod(vo,getName,null);
Double fv=(obj==null?0:Double.parseDouble(obj+""));
ds[i] += fv;
}

}

for (int i = 0; i < fieldNames.size(); i++) {
ds[i] = ds[i] / ps.size();
String fieldName = fieldNames.get(i);

/* 给对象设值 */
String setName = "set"
+ fieldName.substring(0, 1).toUpperCase()
+ fieldName.substring(1);

invokeMethod(t,setName,new Class[]{double.class},ds[i]);

}

return t;
} catch (Exception ex) {
ex.printStackTrace();
}
return null;

}

/**
* 得到最短距离,并返回最短距离索引
*
* @param dists
* @return
*/
public int computOrder(double[] dists) {
double min = 0;
int index = 0;
for (int i = 0; i < dists.length - 1; i++) {
double dist0 = dists[i];
if (i == 0) {
min = dist0;
index = 0;
}
double dist1 = dists[i + 1];
if (min > dist1) {
min = dist1;
index = i + 1;
}
}

return index;
}

/**
* 计算距离(相似性) 采用欧几里得算法
*
* @param p0
* @param p1
* @return
*/
public double distance(T p0, T p1) {
double dis = 0;
try {

for (int i = 0; i < fieldNames.size(); i++) {
String fieldName = fieldNames.get(i);
String getName = "get"
+ fieldName.substring(0, 1).toUpperCase()
+ fieldName.substring(1);

Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+"");
Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+"");
dis += Math.pow(field0Value - field1Value, 2);
}

} catch (Exception ex) {
ex.printStackTrace();
}
return Math.sqrt(dis);

}

/*------公共方法-----*/
public Object invokeMethod(Object owner, String methodName,Class[] argsClass,
Object... args) {
Class ownerClass = owner.getClass();
try {
Method method=ownerClass.getDeclaredMethod(methodName,argsClass);
return method.invoke(owner, args);
} catch (SecurityException e) {
e.printStackTrace();
} catch (NoSuchMethodException e) {
e.printStackTrace();
} catch (Exception ex) {
ex.printStackTrace();
}

return null;
}

}

 

最后咱们测试一下:

package kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class TestMain {

public static void main(String[] args) {
List<Player> listPlayers=new ArrayList<Player>();

for(int i=0;i<15;i++){

Player p1=new Player();
p1.setName("afei-"+i);
p1.setAssists(i);
p1.setBackboard(i);

//p1.setGoal(new Random(100*i).nextDouble());
p1.setGoal(i*10);
p1.setSteals(i);
//listPlayers.add(p1);
}

Player p1=new Player();
p1.setName("afei1");
p1.setGoal(1);
p1.setAssists(8);
listPlayers.add(p1);

Player p2=new Player();
p2.setName("afei2");
p2.setGoal(2);
listPlayers.add(p2);

Player p3=new Player();
p3.setName("afei3");
p3.setGoal(3);
listPlayers.add(p3);

Player p4=new Player();
p4.setName("afei4");
p4.setGoal(7);
listPlayers.add(p4);

Player p5=new Player();
p5.setName("afei5");
p5.setGoal(8);
listPlayers.add(p5);

Player p6=new Player();
p6.setName("afei6");
p6.setGoal(25);
listPlayers.add(p6);

Player p7=new Player();
p7.setName("afei7");
p7.setGoal(26);
listPlayers.add(p7);

Player p8=new Player();
p8.setName("afei8");
p8.setGoal(27);
listPlayers.add(p8);

Player p9=new Player();
p9.setName("afei9");
p9.setGoal(28);
listPlayers.add(p9);

Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers,3);
List<Player>[] results = kmeans.comput();
for (int i = 0; i < results.length; i++) {
System.out.println("===========类别" + (i + 1) + "================");
List<Player> list = results[i];
for (Player p : list) {
System.out.println(p.getName() + "--->"
+ p.getGoal() + "," + p.getAssists() + ","
+ p.getSteals() + "," + p.getBackboard());
}
}

}

}

 

结果如下

  这个里面涉及到相似度算法,事实证明欧几里得距离算法的实践效果是最优的。
  最后说说kmeans算法的不足:可以看到只能针对数字类型的属性(维),对于其他类型的除非选定合适的数值度量。

 

By 阿飞哥 转载请说明
 


       

 

(window.slotbydup = window.slotbydup || []).push({ id: "u5894387", container: "_0hv0l6ey3zro", async: true });
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  kmeans Java