您的位置:首页 > 编程语言 > Java开发

实际项目中以java面向对象的方式实现K-means算法,把对象聚类

2017-01-18 10:56 417 查看
k-means算法接受输入量
k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。

一,k-means算法介绍:

k-means算法接受输入量 k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。 k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

k-means算法的工作过程说明如下:首先从n个数据对象任意选择 k 个对象作为初始聚类中心;而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);不断重复这一过程直到标准测度函数开始收敛为止。一般都采用均方差作为标准测度函数。k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

二,k-means算法基本步骤:

(1) 从 n个数据对象任意选择 k 个对象作为初始聚类中心;

(2) 根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;

(3) 重新计算每个(有变化)聚类的均值(中心对象);

(4) 计算标准测度函数,当满足一定条件,如函数收敛时,则算法终止;如果条件不满足则回到步骤(2),不断重复直到标准测度函数开始收敛为止。(一般都采用均方差作为标准测度函数。)

三,k-means算法的Java实现:

一共有七个类,General.java代表武将对象, Distance.java距离类计算各个武将到中心武将之间的距离, Cluster.java聚类对象包含一个中心武将和该聚类中所有武将, Kmeans.java核心的聚类算法类, Tool.java工具类用于转换武将的星级为数字等操作, TestKmeans.java测试类即入口文件,
DomParser.java用于读取xml中的681个武将。

具体思路:先从general.xml文件中读取681个武将,然后随机选取初始类中心,计算各个武将到中心武将的距离,根据最小的距离进行聚类,然后重新根据平均值新的聚类的类中心,重新计算各个武将到新的中心武将的距离,直到更新后的聚类与原来的聚类包含的武将不再改变,即收敛时结束

代码如下:

(一)实体类对象:

import lombok.Data;

/**
* 专题分析-各行政区天然气数据实体类
*
* @author liuhai
* @create 2016-12-27 下午 7:23
**/
@Data
public class TA_GXZQLNTRQSJModel {
private int infoId;//信息id
private String province;//省份
private float reserves;//储量
private String years;//年份
private float consumption;//消耗量
private float production;//生产量
}


(二)K-means算法核心

package com.GasBigData.dataMining.service;/**
* Created by Liu海 on 2017/1/16 0016.
*/

import com.GasBigData.thematicAnalysis.model.TA_GXZQLNTRQSJModel;

import java.util.ArrayList;
import java.util.Random;

/**
* 聚类分析算法
*
* @author
* @create 2017-01-16 下午 7:03
**/
public class K_means {
private int k;// 分成多少簇
private int m;// 迭代次数
private int dataSetLength;// 数据集元素个数,即数据集的长度
private ArrayList<TA_GXZQLNTRQSJModel> dataSet;// 数据集链表
private ArrayList<TA_GXZQLNTRQSJModel> center;// 中心链表
private ArrayList<ArrayList<TA_GXZQLNTRQSJModel>> cluster; // 簇
private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
private Random random;

/**
* 设置需分组的原始数据集
*
* @param dataSet
*/

public void setDataSet(ArrayList<TA_GXZQLNTRQSJModel> dataSet) {
this.dataSet = dataSet;
}

/**
* 获取结果分组
*
* @return 结果集
*/

public ArrayList<ArrayList<TA_GXZQLNTRQSJModel>> getCluster() {
return cluster;
}

/**
* 构造函数,传入需要分成的簇数量
*
* @param k
* 簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
*/
public K_means(int k) {
if (k <= 0) {
k = 1;
}
this.k = k;
}

/**
* 初始化
*/
private void init() {
m = 0;
random = new Random();
/*if (dataSet == null || dataSet.size() == 0) {//如果数据源为空,调用自备数据源
initDataSet();
}*/
dataSetLength = dataSet.size();//数据源长度
if (k > dataSetLength) {//k值不能超过数据源长度
k = dataSetLength;
}
center = initCenters();//初始化中心数据链
cluster = initCluster();//空簇
jc = new ArrayList<Float>();
}

/**
* 如果调用者未初始化数据集,则采用内部测试数据集
*/
/*private void initDataSet() {
dataSet = new ArrayList<float[]>();
// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };

for (int i = 0; i < dataSetArray.length; i++) {
dataSet.add(dataSetArray[i]);
}
}
*/
/**
* 初始化中心数据链表,分成多少簇就有多少个中心点
*
* @return 中心点集
*/
private ArrayList<TA_GXZQLNTRQSJModel> initCenters() {
ArrayList<TA_GXZQLNTRQSJModel> center = new ArrayList<TA_GXZQLNTRQSJModel>();
int[] randoms = new int[k];
boolean flag;
int temp = random.nextInt(dataSetLength);//产生0-dataSetLength之间的伪随机数
randoms[0] = temp;
for (int i = 1; i < k; i++) {
flag = true;
while (flag) {
temp = random.nextInt(dataSetLength);
int j = 0;
// 不清楚for循环导致j无法加1
// for(j=0;j<i;++j)
// {
// if(temp==randoms[j]);
// {
// break;
// }
// }
while (j < i) {
if (temp == randoms[j]) {
break;
}
j++;
}
if (j == i) {
flag = false;
}
}
randoms[i] = temp;
}

//测试随机数生成情况
for(int i=0;i<k;i++)
{
System.out.println("test1:randoms["+i+"]="+randoms[i]);
}

// System.out.println();
for (int i = 0; i < k; i++) {
center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
}
return center;
}

/**
* 初始化簇集合
*
* @return 一个分为k簇的空数据的簇集合
*/
private ArrayList<ArrayList<TA_GXZQLNTRQSJModel>> initCluster() {
ArrayList<ArrayList<TA_GXZQLNTRQSJModel>> cluster = new ArrayList<ArrayList<TA_GXZQLNTRQSJModel>>();
for (int i = 0; i < k; i++) {
cluster.add(new ArrayList<TA_GXZQLNTRQSJModel>());
}

return cluster;
}

/**
* 计算两个点之间的距离
*
* @param element
*            点1
* @param center
*            点2
* @return 距离
*/
private float distance(TA_GXZQLNTRQSJModel element, TA_GXZQLNTRQSJModel center) {
float distance = 0.0f;
float x = element.getConsumption() - center.getConsumption();
float z = x * x ;
distance = (float) Math.sqrt(z);//返回z的平方根

return distance;
}

/**
* 获取距离集合中最小距离的位置
*
* @param distance
*            距离数组
* @return 最小距离在距离数组中的位置
*/
private int minDistance(float[] distance) {
float minDistance = distance[0];
int minLocation = 0;
for (int i = 1; i < distance.length; i++) {//排序方法进行排序求得最小距离
if (distance[i] < minDistance) {
minDistance = distance[i];
minLocation = i;
} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
{
if (random.nextInt(10) < 5) {
minLocation = i;
}
}
}

return minLocation;
}

/**
* 核心,将当前元素放到最小距离中心相关的簇中
*/
private void clusterSet() {
float[] distance = new float[k];
for (int i = 0; i < dataSetLength; i++) {
for (int j = 0; j < k; j++) {
distance[j] = distance(dataSet.get(i), center.get(j));//返回两点间距离
// System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);

}
int minLocation = minDistance(distance);//最小距离在距离数组中的位置
// System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);
// System.out.println();

cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中

}
}

/**
* 求两点误差平方的方法
*
* @param element
*            点1
* @param center
*            点2
* @return 误差平方
*/
private float errorSquare(TA_GXZQLNTRQSJModel element, TA_GXZQLNTRQSJModel center) {
float x = element.getConsumption() - center.getConsumption();

float errSquare = x * x ;

return errSquare;
}

/**
* 计算误差平方和准则函数方法
*/
private void countRule() {
float jcF = 0;
for (int i = 0; i < cluster.size(); i++) {
for (int j = 0; j < cluster.get(i).size(); j++) {
jcF += errorSquare(cluster.get(i).get(j), center.get(i));

}
}
jc.add(jcF);
}

/**
* 设置新的簇中心方法
*/
private void setNewCenter() {
for (int i = 0; i < k; i++) {
int n = cluster.get(i).size();
if (n != 0) {
float newCenter = 0;
TA_GXZQLNTRQSJModel mod = new TA_GXZQLNTRQSJModel();
for (int j = 0; j < n; j++) {
newCenter += cluster.get(i).get(j).getConsumption();
}
// 设置一个平均值
newCenter = newCenter / n;
mod.setConsumption(newCenter);
center.set(i, mod);
}
}
}

/**
* 打印数据,测试用
*
* @param dataArray
*            数据集
* @param dataArrayName
*            数据集名称
*/
public void printDataArray(ArrayList<TA_GXZQLNTRQSJModel> dataArray,
String dataArrayName) {
for (int i = 0; i < dataArray.size(); i++) {
System.out.println("print:" + dataArrayName + "[" + i + "]={"
+ dataArray.get(i).getConsumption() + "," + dataArray.get(i).getProvince() + "}");
}
System.out.println("===================================");
}

/**
* Kmeans算法核心过程方法
*/
private void kmeans() {
init();
// printDataArray(dataSet,"initDataSet");
// printDataArray(center,"initCenter");

// 循环分组,直到误差不变为止
while (true) {
clusterSet();//给簇赋值,把相关的数据放到相关的簇中
// for(int i=0;i<cluster.size();i++)
// {
// printDataArray(cluster.get(i),"cluster["+i+"]");
// }

countRule();

// System.out.println("count:"+"jc["+m+"]="+jc.get(m));

// System.out.println();
// 误差不变了,分组完成
if (m != 0) {
if (jc.get(m) - jc.get(m - 1) == 0) {
break;//跳出循环
}
}

setNewCenter();//仍然有误差,设置新的簇中心,继续
// printDataArray(center,"newCenter");
m++;
cluster.clear();
cluster = initCluster();
}

// System.out.println("note:the times of repeat:m="+m);//输出迭代次数
}

/**
* 执行算法
*/
public void execute() {
long startTime = System.currentTimeMillis();
System.out.println("kmeans begins");
kmeans();
long endTime = System.currentTimeMillis();
System.out.println("kmeans running time=" + (endTime - startTime)
+ "ms");
System.out.println("kmeans ends");
System.out.println();
}
}
(三)调用函数

/**
* 各行政区天然气消费聚类分析
*/
public void gxzqtrqxfJLFX(){
ArrayList<TA_GXZQLNTRQSJModel> dataSet = new ArrayList<TA_GXZQLNTRQSJModel>();
dataSet = reservesAnalyseMapper.jlfxTest();
//初始化一个Kmean对象,将k置为10
K_means k=new K_means(4);
k.setDataSet(dataSet);
//执行算法
k.execute();
//得到聚类结果
ArrayList<ArrayList<TA_GXZQLNTRQSJModel>> cluster=k.getCluster();
//查看结果
for(int i=0;i<cluster.size();i++)
{
k.printDataArray(cluster.get(i), "cluster["+i+"]");
}
}


(4)聚类结果

[] [2017-01-18 10:30:16] [DEBUG] ==>  Preparing: SELECT * FROM TA_GXZQLNTRQSJ ss GROUP BY ss.PROVINCE 

[] [2017-01-18 10:30:16] [DEBUG] ==> Parameters: 

[] [2017-01-18 10:30:16] [DEBUG] <==      Total: 31

[] [2017-01-18 10:30:16] [DEBUG] Releasing transactional SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSession@418edfd]

kmeans begins

test1:randoms[0]=30

test1:randoms[1]=4

test1:randoms[2]=19

test1:randoms[3]=1

kmeans running time=3ms

kmeans ends

print:cluster[0][0]={72.43,上海}

print:cluster[0][1]={74.96,山东}

print:cluster[0][2]={76.87,河南}

print:cluster[0][3]={78.16,浙江}

print:cluster[0][4]={84.0,辽宁}

print:cluster[0][5]={82.15,重庆}

print:cluster[0][6]={74.26,陕西}

===================================

print:cluster[1][0]={44.53,内蒙古}

print:cluster[1][1]={45.49,天津}

print:cluster[1][2]={34.46,安徽}

print:cluster[1][3]={50.35,山西}

print:cluster[1][4]={56.08,河北}

print:cluster[1][5]={46.0,海南}

print:cluster[1][6]={40.24,湖北}

print:cluster[1][7]={50.26,福建}

print:cluster[1][8]={40.59,青海}

print:cluster[1][9]={35.48,黑龙江}

===================================

print:cluster[2][0]={113.7,北京}

print:cluster[2][1]={165.17,四川}

print:cluster[2][2]={133.83,广东}

print:cluster[2][3]={169.87,新疆}

print:cluster[2][4]={127.7,江苏}

===================================

print:cluster[3][0]={4.63,云南}

print:cluster[3][1]={22.58,吉林}

print:cluster[3][2]={17.88,宁夏}

print:cluster[3][3]={8.25,广西}

print:cluster[3][4]={15.19,江西}

print:cluster[3][5]={24.4,湖南}

print:cluster[3][6]={25.2,甘肃}

print:cluster[3][7]={0.0,西藏}

print:cluster[3][8]={10.62,贵州}

===================================

数据是从数据库查出来的,如有问题请各位大神指正!
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  java 算法 面向对象
相关文章推荐