
Implementing the K-Means Clustering Algorithm in Java

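2018-03-02 18:05
This is my first blog post, so it is fairly informal.
There are plenty of introductions to K-Means; if the algorithm is unfamiliar, it is worth reading one of them first.
My own summary of the implementation comes down to four steps:
1. Choose a value for k and randomly generate k initial centroids.

2. Compute the distance from every point to each of the k centroids and assign the point to the nearest one.

3. The point set is now split into k clusters; compute the new centroid of each cluster.


4. Repeat steps 2 and 3 over all points. Once the centroid of every cluster stops changing, the algorithm has converged.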
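
The four steps can be sketched in a few dozen lines before looking at the full classes. This is only a minimal illustration, not the code from this post: the class name `KMeansSketch`, the `double[]{x, y}` point representation, the fixed starting centroids, and the `maxIterations` cap are all assumptions made for the example. It stops when no point changes cluster, which amounts to the same thing as the centroids no longer moving.

```java
import java.util.Arrays;

/** Minimal K-Means sketch on 2-D points stored as double[]{x, y}. */
final class KMeansSketch {

    /** Returns the index of the centroid closest to p (squared Euclidean distance). */
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double dx = p[0] - centroids[c][0];
            double dy = p[1] - centroids[c][1];
            double d = dx * dx + dy * dy;
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    /** Runs steps 2 to 4 from the given starting centroids; returns each point's cluster index. */
    static int[] cluster(double[][] points, double[][] centroids, int maxIterations) {
        int[] labels = new int[points.length];
        Arrays.fill(labels, -1);
        for (int iter = 0; iter < maxIterations; iter++) {
            // Step 2: assign every point to its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < points.length; i++) {
                int c = nearest(points[i], centroids);
                if (c != labels[i]) {
                    labels[i] = c;
                    changed = true;
                }
            }
            // Step 4: stop once no assignment changed.
            if (!changed) {
                break;
            }
            // Step 3: move each centroid to the mean of its points.
            double[][] sums = new double[centroids.length][2];
            int[] counts = new int[centroids.length];
            for (int i = 0; i < points.length; i++) {
                sums[labels[i]][0] += points[i][0];
                sums[labels[i]][1] += points[i][1];
                counts[labels[i]]++;
            }
            for (int c = 0; c < centroids.length; c++) {
                if (counts[c] > 0) { // an empty cluster keeps its old centroid
                    centroids[c][0] = sums[c][0] / counts[c];
                    centroids[c][1] = sums[c][1] / counts[c];
                }
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        double[][] points = {{1, 1}, {1.5, 2}, {1, 1.5}, {8, 8}, {9, 9}, {8.5, 9.5}};
        double[][] centroids = {{0, 0}, {10, 10}}; // step 1, fixed here so the run is repeatable
        System.out.println(Arrays.toString(cluster(points, centroids, 100)));
    }
}
```

With these six points the first three end up in cluster 0 and the last three in cluster 1. The `centroids` array is updated in place, so it holds the final centroids after the call.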

The code follows.
1. The entry class. It reads the data source, trains the model, and writes the result. The data file and the source code are linked at the end of the post.

package com.hyr.kmeans;

import au.com.bytecode.opencsv.CSVReader;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KmeansMain {

    public static void main(String[] args) throws IOException {
        // Read the data source file
        CSVReader reader = new CSVReader(new FileReader("src/main/resources/data.csv")); // data source
        List<String[]> myEntries = reader.readAll(); // one row per point, e.g. 6.8, 12.6
        reader.close();

        // Convert the rows into the list of data points
        List<Point> points = new ArrayList<Point>();
        for (String[] entry : myEntries) {
            points.add(new Point(Float.parseFloat(entry[0]), Float.parseFloat(entry[1])));
        }

        int k = 6; // value of K
        int type = 1; // distance formula: 1 = L1, 2 = Canberra, 3 = Euclidean
        KmeansModel model = Kmeans.run(points, k, type);

        // try-with-resources closes the writer even if a write fails
        try (FileWriter writer = new FileWriter("src/main/resources/out.csv")) {
            writer.write("==================== K is " + model.getK() + " , Object Funcion Value is " + model.getOfv() + " , calc_distance_type is " + model.getCalc_distance_type() + " ====================\n");
            int i = 0;
            for (Cluster cluster : model.getClusters()) {
                i++;
                writer.write("==================== classification " + i + " ====================\n");
                for (Point point : cluster.getPoints()) {
                    writer.write(point.toString() + "\n");
                }
                writer.write("\n");
                writer.write("centroid is " + cluster.getCentroid().toString());
                writer.write("\n\n");
            }
        }
    }

}
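
If pulling in opencsv is not wanted for a two-column file, the rows can also be parsed with the standard library alone. The sketch below is an assumption-laden alternative, not part of the original project: the class `CsvPointsSketch` and its methods are made up for illustration, and it only handles plain `x,y` rows with no quoting or header line.

```java
import java.util.ArrayList;
import java.util.List;

/** Parses simple "x,y" rows without a CSV library (no quoted fields, no header). */
final class CsvPointsSketch {

    /** Parses one row such as "6.8, 12.6" into {x, y}; spaces around values are ignored. */
    static double[] parseRow(String line) {
        String[] fields = line.split(",");
        if (fields.length < 2) {
            throw new IllegalArgumentException("Expected two columns: " + line);
        }
        return new double[]{
                Double.parseDouble(fields[0].trim()),
                Double.parseDouble(fields[1].trim())
        };
    }

    /** Parses every non-blank line; blank lines are skipped. */
    static List<double[]> parseRows(List<String> lines) {
        List<double[]> points = new ArrayList<>();
        for (String line : lines) {
            if (!line.trim().isEmpty()) {
                points.add(parseRow(line));
            }
        }
        return points;
    }
}
```

The lines themselves could come from `java.nio.file.Files.readAllLines(Paths.get("src/main/resources/data.csv"))`, which returns a `List<String>` suitable for `parseRows`.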

2. The model class, i.e. the final trained result: the clusters, the value of K, the distance type used, and the objective function value.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class KmeansModel {

    private List<Cluster> clusters = new ArrayList<Cluster>();
    private Double ofv; // objective function value
    private int k; // value of K
    private int calc_distance_type; // distance formula used for clustering

    public KmeansModel(List<Cluster> clusters, Double ofv, int k, int calc_distance_type) {
        this.clusters = clusters;
        this.ofv = ofv;
        this.k = k;
        this.calc_distance_type = calc_distance_type;
    }

    public List<Cluster> getClusters() {
        return clusters;
    }

    public Double getOfv() {
        return ofv;
    }

    public int getK() {
        return k;
    }

    public int getCalc_distance_type() {
        return calc_distance_type;
    }
}
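3. The data point class. It holds the point's coordinates (this code uses only two dimensions, x and y) and the distance calculation. The distance formula is selected by type; a few common formulas are provided.

package com.hyr.kmeans;

public class Point {

    private Float x; // x axis
    private Float y; // y axis

    public Point(Float x, Float y) {
        this.x = x;
        this.y = y;
    }

    public Float getX() {
        return x;
    }

    public void setX(Float x) {
        this.x = x;
    }

    public Float getY() {
        return y;
    }

    public void setY(Float y) {
        this.y = y;
    }

    @Override
    public String toString() {
        return "Point{" +
                "x=" + x +
                ", y=" + y +
                '}';
    }

    /**
     * Computes the distance from this point to a centroid.
     *
     * @param centroid the centroid
     * @param type     distance formula: 1 = L1, 2 = Canberra, 3 = Euclidean
     * @return the distance under the selected formula
     */
    public Double calculateDistance(Point centroid, int type) {
        switch (type) {
            case 1:
                return calcL1Distance(centroid);
            case 2:
                return calcCanberraDistance(centroid);
            case 3:
                return calcEuclidianDistance(centroid);
            default:
                // Fail fast instead of returning null, which would cause a NullPointerException later
                throw new IllegalArgumentException("Unknown distance type: " + type);
        }
    }

    /*
     Distance formulas
     */

    // L1 distance divided by the number of dimensions (the mean absolute difference).
    // The constant factor does not change which centroid is nearest.
    private Double calcL1Distance(Point centroid) {
        double res = Math.abs(getX() - centroid.getX()) + Math.abs(getY() - centroid.getY());
        return res / 2d;
    }

    private double calcEuclidianDistance(Point centroid) {
        return Math.sqrt(Math.pow((centroid.getX() - getX()), 2) + Math.pow((centroid.getY() - getY()), 2));
    }

    // Canberra distance, also divided by the number of dimensions
    private double calcCanberraDistance(Point centroid) {
        double res = canberraTerm(getX(), centroid.getX()) + canberraTerm(getY(), centroid.getY());
        return res / 2d;
    }

    // One coordinate's term; treated as 0 when both values are 0 to avoid 0/0 = NaN
    private static double canberraTerm(float a, float b) {
        float denominator = Math.abs(a) + Math.abs(b);
        return denominator == 0f ? 0d : Math.abs(a - b) / denominator;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Point)) { // also covers null
            return false;
        }
        Point other = (Point) obj;
        return getX().equals(other.getX()) && getY().equals(other.getY());
    }

    @Override
    public int hashCode() { // kept consistent with equals
        return 31 * getX().hashCode() + getY().hashCode();
    }
}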
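
For reference, the three distance measures named in the Point class can be written as standalone functions using their usual textbook definitions. This is a sketch, not the post's code: the class `DistanceSketch` and its method names are made up for illustration, it does not apply the division by 2 that the listing above uses for the L1 and Canberra variants, and treating a 0/0 Canberra term as 0 is a convention assumed here.

```java
/** Textbook 2-D distance measures between (x1, y1) and (x2, y2). */
final class DistanceSketch {

    /** L1 (Manhattan) distance: sum of absolute coordinate differences. */
    static double manhattan(double x1, double y1, double x2, double y2) {
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    /** Euclidean distance: straight-line length between the two points. */
    static double euclidean(double x1, double y1, double x2, double y2) {
        return Math.hypot(x1 - x2, y1 - y2);
    }

    /** Canberra distance: sum over coordinates of |a - b| / (|a| + |b|). */
    static double canberra(double x1, double y1, double x2, double y2) {
        return canberraTerm(x1, x2) + canberraTerm(y1, y2);
    }

    /** One coordinate's Canberra term; defined as 0 when both values are 0. */
    private static double canberraTerm(double a, double b) {
        double denominator = Math.abs(a) + Math.abs(b);
        return denominator == 0 ? 0 : Math.abs(a - b) / denominator;
    }

    public static void main(String[] args) {
        System.out.println(manhattan(0, 0, 3, 4)); // 3 + 4
        System.out.println(euclidean(0, 0, 3, 4)); // length of a 3-4-5 triangle's hypotenuse
        System.out.println(canberra(1, 0, 3, 0));  // 2/4 for x, 0 for y
    }
}
```

Scaling a distance by a constant, as the Point class does, leaves the nearest centroid unchanged, but it does change the objective function value that is reported.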
4. A cluster produced by training. It holds the cluster's centroid, the set of points that belong to the cluster, and whether the cluster has converged.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class Cluster {

    private List<Point> points = new ArrayList<Point>(); // points that belong to this cluster
    private Point centroid; // centroid of this cluster
    private boolean isConvergence = false; // true once the centroid has stopped changing

    public Point getCentroid() {
        return centroid;
    }

    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    @Override
    public String toString() {
        return centroid.toString();
    }

    public List<Point> getPoints() {
        return points;
    }

    public void setPoints(List<Point> points) {
        this.points = points;
    }

    // Clears the assigned points before the next classification pass
    public void initPoint() {
        points.clear();
    }

    public boolean isConvergence() {
        return isConvergence;
    }

    public void setConvergence(boolean convergence) {
        isConvergence = convergence;
    }
}
5. The K-Means training class. It repeats the four steps described above until every cluster has converged.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Kmeans {

    /**
     * kmeans
     *
     * @param points the data set
     * @param k      value of K
     * @param type   distance formula to use
     */
    public static KmeansModel run(List<Point> points, int k, int type) {
        // Initialize the centroids
        List<Cluster> clusters = initCentroides(points, k);

        while (!checkConvergence(clusters)) { // have all clusters converged?
            // 1. Compute distances and classify every point
            // 2. Check whether each centroid changed; if not, that cluster has converged
            // 3. Otherwise replace the centroid with the new one
            initClusters(clusters); // clear the points held by each cluster
            classifyPoint(points, clusters, type); // compute distances and classify
            recalcularCentroides(clusters); // recompute the centroids
        }

        // Compute the objective function value
        Double ofv = calcularObjetiFuncionValue(clusters);

        KmeansModel kmeansModel = new KmeansModel(clusters, ofv, k, type);

        return kmeansModel;
    }

    /**
     * Initializes k centroids.
     *
     * @param points the point set
     * @param k      value of K
     * @return the list of clusters, each holding one random centroid
     */
    private static List<Cluster> initCentroides(List<Point> points, Integer k) {
        List<Cluster> centroides = new ArrayList<Cluster>();

        // Find the bounds of the data set (minimum and maximum x and y over all points)
        Float max_X = Float.NEGATIVE_INFINITY;
        Float max_Y = Float.NEGATIVE_INFINITY;
        Float min_X = Float.POSITIVE_INFINITY;
        Float min_Y = Float.POSITIVE_INFINITY;
        for (Point point : points) {
            max_X = max_X < point.getX() ? point.getX() : max_X;
            max_Y = max_Y < point.getY() ? point.getY() : max_Y;
            min_X = min_X > point.getX() ? point.getX() : min_X;
            min_Y = min_Y > point.getY() ? point.getY() : min_Y;
        }
        System.out.println("min_X:" + min_X + ",max_X:" + max_X + ",min_Y:" + min_Y + ",max_Y:" + max_Y);

        // Randomly place k centroids inside those bounds
        Random random = new Random();
        for (int i = 0; i < k; i++) {
            float x = random.nextFloat() * (max_X - min_X) + min_X;
            float y = random.nextFloat() * (max_Y - min_Y) + min_Y; // fixed: the offset was min_X
            Cluster c = new Cluster();
            Point centroide = new Point(x, y); // random initial centroid
            c.setCentroid(centroide);
            centroides.add(c);
        }

        return centroides;
    }

    /**
     * Recomputes the centroids.
     *
     * @param clusters the clusters to update
     */
    private static void recalcularCentroides(List<Cluster> clusters) {
        for (Cluster c : clusters) {
            if (c.getPoints().isEmpty()) { // an empty cluster keeps its centroid
                c.setConvergence(true);
                continue;
            }

            // The mean of the points becomes the new centroid
            Float x;
            Float y;
            Float sum_x = 0f;
            Float sum_y = 0f;
            for (Point point : c.getPoints()) {
                sum_x += point.getX();
                sum_y += point.getY();
            }
            x = sum_x / c.getPoints().size();
            y = sum_y / c.getPoints().size();
            Point nuevoCentroide = new Point(x, y); // new centroid

            if (nuevoCentroide.equals(c.getCentroid())) { // centroid unchanged: this cluster has converged
                c.setConvergence(true);
            } else {
                c.setConvergence(false); // the centroid moved, so clear any earlier converged flag
                c.setCentroid(nuevoCentroide);
            }
        }
    }

    /**
     * Computes distances and classifies the points.
     *
     * @param points   the point set
     * @param clusters the clusters
     * @param type     distance formula to use
     */
    private static void classifyPoint(List<Point> points, List<Cluster> clusters, int type) {
        for (Point point : points) {
            Cluster masCercano = clusters.get(0); // cluster this point will be assigned to
            Double minDistancia = Double.MAX_VALUE; // smallest distance so far
            for (Cluster cluster : clusters) {
                Double distancia = point.calculateDistance(cluster.getCentroid(), type); // distance to this cluster's centroid
                if (minDistancia > distancia) { // keep the smallest of the k distances
                    minDistancia = distancia;
                    masCercano = cluster; // nearest cluster so far
                }
            }
            masCercano.getPoints().add(point); // add the point to its nearest cluster
        }
    }

    private static void initClusters(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            cluster.initPoint();
        }
    }

    /**
     * Checks convergence.
     *
     * @param clusters the clusters
     * @return true only if every cluster has converged
     */
    private static boolean checkConvergence(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            if (!cluster.isConvergence()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the objective function value: the sum of the distances
     * from every point to the centroid of its cluster.
     *
     * @param clusters the clusters
     * @return the objective function value
     */
    private static Double calcularObjetiFuncionValue(List<Cluster> clusters) {
        Double ofv = 0d;

        for (Cluster cluster : clusters) {
            for (Point point : cluster.getPoints()) {
                // Note: always uses distance type 1, regardless of the type used for clustering
                int type = 1;
                ofv += point.calculateDistance(cluster.getCentroid(), type);
            }
        }

        return ofv;
    }
}
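
The initCentroides method above draws its starting centroids uniformly from the bounding box of the data, so a centroid can land where no point is nearest to it and its cluster stays empty. A common alternative, which is not what this post's code does, is to use k distinct data points as the starting centroids. The sketch below shows that idea under assumptions: `InitSketch`, `pickInitialCentroids`, and the `double[]{x, y}` point representation are hypothetical names chosen for the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Chooses starting centroids from the data itself instead of the bounding box. */
final class InitSketch {

    /** Picks k distinct input points (distinct by position in the list) as starting centroids. */
    static List<double[]> pickInitialCentroids(List<double[]> points, int k, Random random) {
        if (k < 1 || k > points.size()) {
            throw new IllegalArgumentException("k must be between 1 and the number of points");
        }
        // Shuffle a copy so the caller's list keeps its order.
        List<double[]> shuffled = new ArrayList<>(points);
        Collections.shuffle(shuffled, random);

        List<double[]> centroids = new ArrayList<>();
        for (double[] p : shuffled.subList(0, k)) {
            centroids.add(p.clone()); // copy, so moving a centroid never modifies a data point
        }
        return centroids;
    }

    public static void main(String[] args) {
        List<double[]> points = new ArrayList<>();
        points.add(new double[]{6.8, 12.6});
        points.add(new double[]{0.8, 9.8});
        points.add(new double[]{4.4, 6.5});
        for (double[] c : pickInitialCentroids(points, 2, new Random(42))) {
            System.out.println(c[0] + ", " + c[1]);
        }
    }
}
```

Every centroid chosen this way starts with at least one point at distance zero. If the data contains duplicate coordinates, two starting centroids can still coincide, so this reduces the empty-cluster problem without ruling it out.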

Final training result:
==================== K is 6 , Object Funcion Value is 21.82857036590576 , calc_distance_type is 3 ====================
==================== classification 1 ====================
Point{x=3.5, y=12.5}

centroid is Point{x=3.5, y=12.5}

==================== classification 2 ====================
Point{x=6.8, y=12.6}
Point{x=7.8, y=12.2}
Point{x=8.2, y=11.1}
Point{x=9.6, y=11.1}

centroid is Point{x=8.1, y=11.75}

==================== classification 3 ====================
Point{x=4.4, y=6.5}
Point{x=4.8, y=1.1}
Point{x=5.3, y=6.4}
Point{x=6.6, y=7.7}
Point{x=8.2, y=4.5}
Point{x=8.4, y=6.9}
Point{x=9.0, y=3.4}

centroid is Point{x=6.671428, y=5.2142863}

==================== classification 4 ====================
Point{x=6.0, y=19.9}
Point{x=6.2, y=18.5}
Point{x=5.3, y=19.4}
Point{x=7.6, y=17.4}

centroid is Point{x=6.275, y=18.800001}

==================== classification 5 ====================
Point{x=0.8, y=9.8}
Point{x=1.2, y=11.6}
Point{x=2.8, y=9.6}
Point{x=3.8, y=9.9}

centroid is Point{x=2.15, y=10.225}

==================== classification 6 ====================
Point{x=6.1, y=14.3}

centroid is Point{x=6.1, y=14.3}
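
As a quick check on the output above, the centroid of a cluster should be the mean of the points listed under it. The sketch below recomputes it for classification 2 from the four printed points; the class `CentroidCheck` is only an illustration, and it uses `double` where the post's code uses `Float`, so the last digits can differ slightly.

```java
/** Recomputes a centroid as the coordinate-wise mean of a cluster's points. */
final class CentroidCheck {

    /** Returns {mean x, mean y} of the given points. */
    static double[] mean(double[][] points) {
        double sumX = 0;
        double sumY = 0;
        for (double[] p : points) {
            sumX += p[0];
            sumY += p[1];
        }
        return new double[]{sumX / points.length, sumY / points.length};
    }

    public static void main(String[] args) {
        // The four points printed under "classification 2"
        double[][] classification2 = {{6.8, 12.6}, {7.8, 12.2}, {8.2, 11.1}, {9.6, 11.1}};
        double[] centroid = mean(classification2);
        System.out.println("x=" + centroid[0] + ", y=" + centroid[1]);
    }
}
```

By hand, x = (6.8 + 7.8 + 8.2 + 9.6) / 4 = 32.4 / 4 = 8.1 and y = (12.6 + 12.2 + 11.1 + 11.1) / 4 = 47.0 / 4 = 11.75, which agrees with the reported centroid Point{x=8.1, y=11.75} up to floating-point rounding.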



Code download: http://download.csdn.net/download/huangyueranbbc/10267041
GitHub: https://github.com/huangyueranbbc/KmeansDemo