您的位置:首页 > 编程语言 > Python开发

数据分类K—means 算法的python代码实现

2016-04-05 19:10 561 查看
k_means算法是用来进行数据分类的,尤其适用于相同维度数据集合的分类。

形象来说,以2维平面为例,原始数据则可以看做是分布在其中的一些点,而分类则是需要找出一些中心点(可能不属于这些数据)将这些数据点分成一个个集合,如此便完成了分类。

算法主要步骤为:

1) 选择 K 个起始的中心点

2) 按照欧拉距离最小的原则,将原始数据分类到这K个中心点形成的集合中

3) 重新计算这K 个中心点的坐标,它的新坐标为它所形成的集合中所有点各维度的算术平均值所形成的新点

4) 以重新形成的K 个中心点 重复2),3), 直至中心点坐标不再变化

算法的思想很简单直接,但是有一些问题需要指出:

1) 如何选择最初的K 个中心点?

在实现中大多数都采用 随机选择的方法,当然也可以人为指定。如果使用随机选取的方法,会造成最终的结果形成波动

2) 如何定 K 值的大小?

这个只能凭借经验了,而且和数据的分布特点也有关系,很难确定出最好的方案

PS: 注意在程序实现过程中,数据可能是Int 型,但是由于新的中心点需要算术平均,所以注意进行类型转换。

运行方式:

将代码保存为.py格式,默认使用的数据是代码文件所在目录下data目录下的 k_means.txt 文件分别作为源数据输入。以上参数可以在源代码中修改,也可以使用命令行参数传入,参考以下启动方式:

python k_means.py k_means.txt 3
命令中后参数为输入数据的途径 和 自行提供的 K 值(可以省略此参数,默认K=3)。

实验数据来源:
http://blog.csdn.net/androidlushangderen/article/details/43373159
特别感谢。

python 源代码如下:

__author__ = 'Administrator'
import re
import sys
import random
import copy

k = 3
dt = []
lg = wd = 0

def K_means():
ctr = [[0 for j in range(wd)] for i in range(k)]
st = set()
for i in range(k):
n=-1
while(1):
n = random.randint(0,lg-1)
st.add(n)
if len(st)>i:
break
ctr[i] = copy.deepcopy(dt
)
# print ctr
flg = 0
blg = []
ls = [0 for i in range(lg)]
blg.append(ls)
x=0
while(1):
# print ctr
x+=1
flg=0
# print "x: ",x
ls = [0 for i in range(lg)]

for i in range(lg):
# print dt[i],ctr[0]
mn=getDis(dt[i],ctr[0])
l=0
# print "dis: ",mn
for j in range(1,k):
a = getDis(dt[i],ctr[j])
# print "d:",a
if mn>a:
mn=a
l=j
ls[i] = l
blg.append(ls)
# print "ls:",ls
for i in range(k):
n = 0
su = [0 for aa in range(wd)]
for j in range(lg):
if ls[j]==i:
for a in range(wd):
su[a] += dt[j][a]
n+=1
for a in range(wd):
ctr[i][a] = float(su[a])/n
for i in range(lg):
if blg[x][i]!=blg[x-1][i]:
flg=1
if flg==0:
break
#print ctr#,blg
print "After %d times....\n"%(x+1)
return ctr ,blg[x]

def getDis(a,b):
dis = 0
w = len(a)
for i  in range(w):
dis += (a[i]-b[i])*(a[i]-b[i])
return dis

if __name__ == '__main__':
data = "data/k_means.txt"
if  len(sys.argv)>1:
data = sys.argv[1]
if len(sys.argv)>2:
k = int(sys.argv[2])
fp=open(data,"r")
for line in fp:
line = re.sub(r"\n\r","",line)
ls = line.split()
wd = len(ls)
la = ls
for i in range(wd):
la[i] = int(ls[i])
dt.append(la)
lg = len(dt) # ,dt[1],dt[1][0]==4
print "data:\n",dt
ls,la = K_means()
print "center:\n",ls
print "blongs:\n",la
fp.close()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息