您的位置:首页 > 编程语言 > Java开发

K-Means 算法(Java)

2015-12-12 11:04 477 查看
K-Means 算法的原理见我的上一篇文章。本文介绍 K-Means 的 Java 实现，实现思路参考了 Python 版本。

一、数据点的实现

package com.meachine.learning.kmeans;

import java.util.ArrayList;
import java.util.List;

/**
 * A data point with an arbitrary number of numeric coordinates.
 *
 * <p>Besides its coordinates, a point carries the bookkeeping K-Means needs:
 * the id of the cluster it is assigned to and its distance to that cluster's center.
 */
public class Point {
    // Running counter that gives each new Point a sequential id (plain static int, not thread-safe).
    private static int num;
    private int id;
    private ArrayList<Double> values;
    // Id of the assigned cluster; -1 means "not assigned yet".
    private int clusterId = -1;
    // Distance to the assigned cluster's center. Starts at +infinity so that any finite
    // distance compares smaller (Integer.MAX_VALUE, ~2.1e9, would reject larger distances).
    private double minDist = Double.POSITIVE_INFINITY;

    public Point() {
        id = ++num;
        values = new ArrayList<>();
    }

    /**
     * Appends one coordinate, increasing the dimension by one.
     *
     * @param e the coordinate value to append
     */
    public void add(double e) {
        values.add(e);
    }

    /** @return this point's sequential id */
    public int getId() {
        return id;
    }

    /**
     * Returns the number of coordinates. Derived from the value list so it can never
     * drift out of sync with it.
     *
     * @return the dimension of this point
     */
    public int getDimensioNum() {
        return values.size();
    }

    /**
     * Returns the live coordinate list; changes made through it affect this point.
     *
     * @return the coordinates of this point
     */
    public ArrayList<Double> getValues() {
        return values;
    }

    /**
     * Replaces the coordinates with a copy of {@code values}.
     *
     * @param values the new coordinates
     */
    public void setValues(List<Double> values) {
        this.values = new ArrayList<>(values);
    }

    /** @return id of the assigned cluster, or -1 if unassigned */
    public int getClusterId() {
        return clusterId;
    }

    /** @param clusterId id of the cluster this point is assigned to */
    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /** @return distance to the assigned cluster's center (+infinity until set) */
    public double getMinDist() {
        return minDist;
    }

    /** @param minDist distance to the assigned cluster's center */
    public void setMinDist(double minDist) {
        this.minDist = minDist;
    }
}


二、数据簇的实现

package com.meachine.learning.kmeans;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;

/**
 * A cluster: a caller-assigned id, its current center point, and a member count.
 */
public class Cluster {
    // Cluster identifier, supplied by the caller.
    private int clusterId;
    // Number of points belonging to this cluster; starts at 0 and changes only via setNumOfPoints.
    private int numOfPoints;
    // Current center of the cluster; null until one is set.
    private Point center;

    /**
     * Creates a cluster with no center yet.
     *
     * @param id the cluster id
     */
    public Cluster(int id) {
        this(id, null);
    }

    /**
     * Creates a cluster with an initial center.
     *
     * @param id     the cluster id
     * @param center the initial center (may be null)
     */
    public Cluster(int id, Point center) {
        this.clusterId = id;
        this.center = center;
        this.numOfPoints = 0;
    }

    /** @return the cluster id */
    public int getClusterId() {
        return clusterId;
    }

    /** @param clusterId the new cluster id */
    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /** @return the stored member count */
    public int getNumOfPoints() {
        return numOfPoints;
    }

    /** @param numOfPoints the new member count */
    public void setNumOfPoints(int numOfPoints) {
        this.numOfPoints = numOfPoints;
    }

    /** @return the current center, or null if none has been set */
    public Point getCenter() {
        return center;
    }

    /** @param center the new center */
    public void setCenter(Point center) {
        this.center = center;
    }
}


三、计算数据点距离

package com.meachine.learning.kmeans;

import java.util.List;

/**
 * Strategy for measuring the distance between two coordinate lists.
 *
 * @param <T> element type of the coordinate lists
 */
public interface IDistance<T> {

    /**
     * Computes the distance between two coordinate lists.
     *
     * @param p1 the first coordinate list
     * @param p2 the second coordinate list
     * @return the distance between {@code p1} and {@code p2}
     */
    double getDis(List<T> p1, List<T> p2);
}


  

package com.meachine.learning.kmeans;

import java.util.List;

/**
* 欧式距离
*
*/
public class OujilidDistance<T extends Number> implements IDistance<T> {

public double getDis(List<T> a, List<T> b) {
if (a.size() != b.size()) {
throw new IllegalArgumentException("Size not compatible!");
}
double result = 0;
for (int i = 0; i < a.size(); i++) {
result += Math.pow((a.get(i).doubleValue() - b.get(i).doubleValue()), 2);
}
return Math.sqrt(result);
}

}


四、K-Means算法

  

package com.meachine.learning.kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * K-Means clustering over points loaded from a whitespace-separated text file.
 *
 * <p>Usage: construct, call {@link #init()} to load the data and pick k distinct
 * initial centers, then call {@link #clustering()} to alternate between assigning
 * points to their nearest center and moving each center to its members' mean.
 *
 * @author Cang
 *
 */
public class KMeans {
    // Data file used by the single-argument constructor.
    private static final String DEFAULT_DATA_FILE = "D:/testSet.txt";

    // Number of clusters.
    private int k;
    // Number of dimensions (variables) per point, taken from the first data row.
    private int dimensioNum;
    // Maximum number of assign/update passes performed by clustering().
    private int maxItrNum = 100;
    private IDistance<Double> distance;
    private List<Point> points;
    private List<Cluster> clusters = new ArrayList<Cluster>();
    private String dataFileName;

    /**
     * Creates a K-Means run that reads {@value #DEFAULT_DATA_FILE}.
     *
     * @param k number of clusters; must be positive
     */
    public KMeans(int k) {
        this(k, DEFAULT_DATA_FILE);
    }

    /**
     * Creates a K-Means run over the given data file.
     *
     * @param k            number of clusters; must be positive
     * @param dataFileName path of the data file (one point per line, whitespace-separated numbers)
     * @throws IllegalArgumentException if {@code k <= 0} or {@code dataFileName} is null
     */
    public KMeans(int k, String dataFileName) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (dataFileName == null) {
            throw new IllegalArgumentException("dataFileName must not be null");
        }
        this.k = k;
        this.dataFileName = dataFileName;
    }

    /**
     * Loads the data set, sets up the Euclidean distance and picks the initial centers.
     *
     * @throws IllegalStateException    if the data file cannot be read
     * @throws IllegalArgumentException if the file is malformed or has fewer than k points
     */
    public void init() {
        points = loadDataSet(dataFileName);
        distance = new OujilidDistance<Double>();
        initCluster();
    }

    /**
     * Loads one point per non-blank line; columns are separated by whitespace.
     * Also records the dimension in {@code dimensioNum}.
     *
     * @param fileName path of the data file
     * @return the loaded points, in file order
     */
    private List<Point> loadDataSet(String fileName) {
        List<Point> loaded = new ArrayList<>();
        // try-with-resources closes the reader even when parsing throws.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank lines such as a trailing newline
                }
                String[] fields = trimmed.split("\\s+");
                if (loaded.isEmpty()) {
                    dimensioNum = fields.length;
                } else if (fields.length != dimensioNum) {
                    // A ragged row would otherwise fail later with an index error in changeCentroids.
                    throw new IllegalArgumentException(String.format(
                            "%s line %d has %d columns, expected %d",
                            fileName, lineNo, fields.length, dimensioNum));
                }
                Point point = new Point();
                for (String field : fields) {
                    point.add(Double.parseDouble(field));
                }
                loaded.add(point);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Failed to read data set: " + fileName, e);
        }
        return loaded;
    }

    /**
     * Picks k distinct data points at random as initial centers; cluster ids are 1..k.
     */
    private void initCluster() {
        if (k > points.size()) {
            throw new IllegalArgumentException(
                    "k (" + k + ") exceeds the number of data points (" + points.size() + ")");
        }
        clusters.clear(); // re-running init() must not accumulate clusters
        Random ran = new Random();
        List<Point> candidates = new ArrayList<>(points);
        for (int slot = 0; slot < k; slot++) {
            // Partial Fisher-Yates shuffle: choose among the not-yet-picked tail so no point is chosen twice.
            int pick = slot + ran.nextInt(candidates.size() - slot);
            Point center = candidates.get(pick);
            candidates.set(pick, candidates.get(slot));
            candidates.set(slot, center);
            clusters.add(new Cluster(slot + 1, center));
        }
    }

    /**
     * Runs the K-Means iterations: assign points, print the centers, recompute the centers.
     * Stops when no assignment changes or after {@code maxItrNum} passes.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     */
    public void clustering() {
        if (points == null || distance == null || clusters.isEmpty()) {
            throw new IllegalStateException("init() must be called before clustering()");
        }
        for (int iteration = 0; iteration < maxItrNum; iteration++) {
            boolean changed = assignPointsToNearestCluster();
            System.out.println("Cluster center info:");
            for (Cluster cluster : clusters) {
                System.out.println(cluster.getCenter().getValues());
            }
            changeCentroids();
            if (!changed) {
                break;
            }
        }
    }

    /**
     * Assigns every point to the cluster with the nearest center.
     * Ties go to the earlier cluster.
     *
     * @return true if at least one point changed cluster
     */
    private boolean assignPointsToNearestCluster() {
        boolean changed = false;
        for (Point point : points) {
            // Search from scratch each pass: the centers have moved, so the stored minDist is stale.
            double nearestDist = Double.POSITIVE_INFINITY;
            int nearestId = point.getClusterId();
            for (Cluster cluster : clusters) {
                double dist = distance.getDis(cluster.getCenter().getValues(), point.getValues());
                if (dist < nearestDist) {
                    nearestDist = dist;
                    nearestId = cluster.getClusterId();
                }
            }
            if (nearestId != point.getClusterId()) {
                point.setClusterId(nearestId);
                changed = true;
            }
            point.setMinDist(nearestDist);
        }
        return changed;
    }

    /**
     * Moves each cluster's center to the mean of the points assigned to it.
     * A cluster with no points keeps its previous center.
     */
    private void changeCentroids() {
        for (Cluster cluster : clusters) {
            // Per-dimension sums start fresh for every cluster and every coordinate.
            double[] sums = new double[dimensioNum];
            int members = 0;
            for (Point point : points) {
                if (point.getClusterId() != cluster.getClusterId()) {
                    continue;
                }
                members++;
                List<Double> values = point.getValues();
                for (int i = 0; i < dimensioNum; i++) {
                    sums[i] += values.get(i);
                }
            }
            if (members == 0) {
                continue; // avoid 0/0 = NaN centers for an empty cluster
            }
            ArrayList<Double> newCenterValue = new ArrayList<Double>(dimensioNum);
            for (double sum : sums) {
                newCenterValue.add(sum / members); // mean over this cluster's members only
            }
            Point newCenterPoint = new Point();
            newCenterPoint.setClusterId(cluster.getClusterId());
            newCenterPoint.setValues(newCenterValue);
            cluster.setCenter(newCenterPoint);
        }
    }

    public static void main(String[] args) {
        KMeans kmeans = new KMeans(4);
        kmeans.init();
        kmeans.clustering();
    }
}


  
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: