您的位置:首页 > 其它

《机器学习实战》--KNN

2016-02-09 01:08 267 查看

一、瞎扯

先拉一下仇恨,这篇文章是在喝着走亲戚时带来的饮料,单曲循环着成龙版本的《拯救》的情况下完成的,哈哈,过年难免有些活的太潇洒,于是还是需要写些代码,看些书来收收心,另外新的一年开始了,也该对“懒”开刀了,准备养成写博客这一及其装逼的技能,祝各位同行新年快乐!(开始写的的时候还是大年初一,发布的时候过期了可别怪我)

KNN是机器学习中最简单,最基础的算法之一,算法实现起来没什么难度,但是它的使用范围依然十分广泛,比如书中提到的电影的分类,婚姻网站的配偶分类,手写识别系统,和轨迹预处理中选择候选轨迹等。


[/code]

二、KNN介绍

   2.1监督学习和非监督学习
监督学习:利用一组已知类别的样本调整分类器的参数,使其达到所要求性能的过程,也称为监督训练或有教师学习
举个例子,教婴儿学习的时候,过来一只鸡,就让他叫鸡,过来一只鸭子的时候就让他叫鸭子,教会了一会,随便找一些鸡和鸭子他就能分辨了。这就是监督学习

无监督学习:其中很重要的一类叫聚类
    举个例子,过来两种动物(假设还是鸡和鸭子,但是婴儿不知道),然后他根据动物的相像程度把动物分成两群

 2.2欧几里得距离
这个在论文中经常能见到,它是一个通常采用的距离定义,指在m维空间中两个点之间的真实距离,或者向量的自然长度(即该点到原点的距离)。在二维和三维空间中的欧氏距离就是两点之间的实际距离。

n维下的两个点的欧氏距离:
  两个点 A = (a[1],a[2],…,a
) 和 B = (b[1],b[2],…,b
) 之间的距离 ρ(A,B)
  定义为下面的公式:ρ(A,B) =√ [ ∑( a[i] - b[i] )^2 ] (i = 1,2,…,n)

  2维下就是最常见的距离公式了 sqrt( (x1-x2)^2+(y1-y2)^2 )了

 2.3KNN算法介绍
KNN属于监督学习,他需要一个训练集来训练,然后才能对后面给出的东西(测试集)进行分类。



   通俗的过程:
1.通过训练集训练它,如上图,红色和蓝色就是训练集,他们的类别是已知的,于是现在来了一个未知的东西
(标为绿色),希望通过knn来给它分类,看应该是属于蓝色的还是红色的。
2.计算所有红色点和蓝色点到绿点的距离
3.排序找到最近的k个,上图中实线表示k取3(即取距离最近的三个点),同理虚线表示k取5
4.这k个里面哪个类别的东西多,就判断这个绿色点是属于哪一类的。如果上图看实线,红的比蓝的多那么
绿点画为红的,如果是虚线,则标为蓝的

三、代码

代码我自己用了C++风格的python写了一遍,也就是说不用矩阵计算,那些矩阵被我当作多维数组来用了,代码写的很粗糙,未优化,望见谅。

另外还要再说明一个知识点–归一化

这个很常见,数学课上也讲过就是把数值缩小到0-1之间,这里是把训练集中的样本数据归一化,使用的公式是newValue=(oldValue-min)/(max-min)

1  # -*- coding: UTF-8 -*-
2 from numpy import *
3 from math import *
4
5
6 #定义类
7 class Student(object):
8
9     def __init__(self, distance, label):
10         self.distance = distance
11         self.label = label
12
13 #KNN算法
14 def classify0(inX, dataSet, labels, k):
15     size=shape(dataSet)
16     line=size[0]
17     column=size[1]
18     # print (inX)
19     if(len(inX)!=column):
20         print("unequal!!")
21         return
22
23 #计算当前项inX与其余训练集的欧氏距离
24     sum=0.0000000000
25     disList=[]
26     for i in range(line):
27         for j in range(column):
28             sum+=(inX[j]-dataSet[i,j])**2
29         tmp=Student(sqrt(sum),labels[i])
30         disList.append(tmp)
31         sum=0.0000000000
32 #排序:欧氏距离从小到大
33     disList.sort(lambda x,y:cmp(x.distance,y.distance))
34
35 #取k项判断分类
36     dict={}
37     index=0
38     for item in disList:
39         if(index==k):
40             break
41         index+=1
42         if(dict.has_key(item.label)):
43             dict[item.label]+=1
44         else:
45             dict[item.label]=1
46
47     dict=sorted(dict.iteritems(),key=lambda d:d[1],reverse=True)
48     #返回最可能的值
49     return dict[0][0]
50
51 # def classify0(inX, dataSet, labels, k):
52 #     dataSetSize = dataSet.shape[0]
53 #     diffMat = tile(inX, (dataSetSize,1)) - dataSet
54 #     sqDiffMat = diffMat**2
55 #     sqDistances = sqDiffMat.sum(axis=1)
56 #     distances = sqDistances**0.5
57 #     sortedDistIndicies = distances.argsort()
58 #     classCount={}
59 #     for i in range(k):
60 #         voteIlabel = labels[sortedDistIndicies[i]]
61 #         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
62 #     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
63 #     print sortedClassCount[0][0]
64 #     return sortedClassCount[0][0]
65
66
67 #读取数据
68 def file2matrix (filename):
69     fr=open(filename)
70     lines=fr.readlines()
71     totalCnt=len(lines)
72     totalCol=len(lines[0].strip().split('\t'))
73
74     resultMatrix=zeros((totalCnt,totalCol-1))
75     labelMatrix=[]
76     index=0
77     for line in lines:
78         tmp=line.strip().split('\t')
79         resultMatrix[index,:]=tmp[0:totalCol-1]
80         labelMatrix.append((tmp[-1]))
81         index+=1
82     return resultMatrix,labelMatrix
83
84
85
86 # def file2matrix(filename):
87 #     fr=open(filename)
88 #     arrayOlines=fr.readlines()
89 #     numberOfLines=len(arrayOlines)
90 #     returnMat=zeros((numberOfLines,3))
91 #     classLabelVector=[]
92 #     index=0
93 #     for line in arrayOlines:
94 #         line=line.strip()
95 #         listFromLine=line.split('\t')
96 #         returnMat[index,:]=listFromLine[0:3];
97 #         classLabelVector.append((listFromLine[-1]))
98 #         index+=1
99 #     return returnMat,classLabelVector
100
101
102 #归一化处理  newValue=(oldValue-min)/(max-min)
103 def autoNorm(dataSet):
104     size=shape(dataSet)
105     line=size[0]
106     column=size[1]
107
108     min=zeros((1,column))
109     max=zeros((1,column))
110     # print min
111     index=0
112     for value in dataSet[0,:]:
113         min[0,index]=value
114         max[0,index]=value
115         index+=1
116 #求每一列的最小值和最大值
117     for i in range(line):
118         for j in range(column):
119             if(i==0 and j==0):
120                 continue
121             if(dataSet[i,j]>max[0,j]):
122                 max[0,j]=dataSet[i,j]
123             if(dataSet[i,j]<min[0,j]):
124                 min[0,j]=dataSet[i,j]
125
126     ranges = max-min
127     # print ranges
128     result=zeros((line,column))
129
130     for i in range(line):
131         for j in range(column):
132             result[i,j]=(dataSet[i,j]-min[0,j])/ranges[0,j]
133
134     return result,ranges,min
135
136
137 # def autoNorm(dataSet):
138 #     minVals = dataSet.min(0)
139 #     maxVals = dataSet.max(0)
140 #     ranges = maxVals - minVals
141 #     print ranges
142 #     normDataSet = zeros(shape(dataSet))
143 #     m = dataSet.shape[0]
144 #     normDataSet = dataSet - tile(minVals, (m,1))
145 #     normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide
146 #     return normDataSet, ranges, minVals
147
148
149
150 def datingClassTest():
151     hoRatio = 0.50      #hold out 10%
152     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
153     normMat, ranges, minVals = autoNorm(datingDataMat)
154     m = normMat.shape[0]
155     numTestVecs = int(m*hoRatio)
156     errorCount = 0.0
157     for i in range(numTestVecs):
158         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
159         print "the classifier came back with: %s, the real answer is: %s" % (classifierResult, datingLabels[i])
160         if (classifierResult != datingLabels[i]): errorCount += 1.0
161     print "the total error rate is: %f" % (errorCount/float(numTestVecs))
162     print errorCount
163
164
165 datingClassTest()


四、代码以及测试数据的下载

github:   https://github.com/wlmnzf/Machine-Learning-train/tree/master/KNN

五、感谢

1.《机器学习实战》这本书写的不错,值得学习
2.百度百科提供的图片和解释也不能忘
3.[《什么是无监督学习》 知乎](http://www.zhihu.com/question/23194489)
4.感谢 网易云音乐 《拯救》-成龙  深夜相伴


扫码或者搜索 “会打代码的扫地王大爷” 关注公众号

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: