您的位置:首页 > 其它

机器学习之K-近邻算法(kNN算法)

2016-02-26 22:08 561 查看
机器学习之K-近邻算法(kNN算法)
一、概念
k-近邻算法是根据不同特征值之间的距离来进行分类的一种简单的机器学习方法。本文简单介绍下kNN算法,并用其实现手写数字的识别。

工作原理:
存在一个样本数据集合,也称训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每个数据与其所属分类的关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k近邻算法中k的出处(通常k<20)。最后,我们选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

算法一般流程:



算法的使用步骤:



kNN算法的Python实现:

def classify(inX,dataSet,labels,k):

dataSetSize=dataSet.shape[0]

diffMat=tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5
sortedDistIndicies=distances.argsort()
classCount={}
for i in range(k):
voteLabel=labels[sortedDistIndicies[i]]
classCount[voteLabel]=classCount.get(voteLabel,0)+1
sortedClassCount=sorted(classCount,key=operator.itemgetter(0),reverse=True)
return sortedClassCount[0][0]

注意:不同特征的值取值范围不同时,通常应该讲数据进行归一化处理。
归一化0-1之间公式:
newValue=(oldValue-min)/(max-min)
归一化到-1-1之间公式:
newValue=(oldValue-mid)/mid
归一化到0-1之间代码:
def autoNorm(dataSet):

minVals=dataSet.min(0)

maxVals=dataSet.max(0)

ranges=maxVals-minVals

normDataSet=zeros(shape(dataSet))

m=dataSet.shape[0]

normDataSet=dataSet-tile(minVals,(m,1))

normDataSet=normDataSet/tile(ranges,(m,1))

return normDataSet,ranges,minVals

二、特点
优点:精度高、对异常值不敏感、无数据输入假定;
缺点:计算复杂度高、空间复杂度高;
适用数据范围:数值型和标称型。

三、适用场景
由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比( 较近邻居的权重比较远邻居的权重大。)(一种常见的加权方案是给每个邻居权重赋值为1/ d,其中d是到邻居的距离。这个方案是一个线性插值的推广。)

K近邻算法也适用于连续变量估计,比如适用反距离加权平均多个K近邻点确定测试点的值。该算法的功能有:

从目标区域抽样计算欧式或马氏距离;

在交叉验证后的RMSE基础上选择启发式最优的K邻域;

计算多元k-最近邻居的距离倒数加权平均。

四、应用举例(手写数字识别)
原始图:



经过转换后的向量图:



数据准备:
为了能使用上述的分类器,我们必须将图像格式化处理为一个向量。我们将原始的32*32的二进制图像举证转换成为1*1024的向量。
图像转换为向量的代码如下:
def img2vector(filename):

returnVect = zeros((1,1024))

fr = open(filename)

for i in range(32):

lineStr = fr.readline()

for j in range(32):

returnVect[0,32*i+j] = int(lineStr[j])

return returnVect

测试代码:
def handwritingClassTest():

hwLabels = []

trainingFileList = listdir('trainingDigits') #load the training set

m = len(trainingFileList)

trainingMat = zeros((m,1024))

for i in range(m):

fileNameStr = trainingFileList[i]

fileStr = fileNameStr.split('.')[0] #take off .txt

classNumStr = int(fileStr.split('_')[0])

hwLabels.append(classNumStr)

trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)

testFileList = listdir('testDigits') #iterate through the test set

errorCount = 0.0

mTest = len(testFileList)

for i in range(mTest):

fileNameStr = testFileList[i]

fileStr = fileNameStr.split('.')[0] #take off .txt

classNumStr = int(fileStr.split('_')[0])

vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)

classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)

if (classifierResult != classNumStr): errorCount += 1.0

print "\nthe total number of errors is: %d" % errorCount

print "\nthe total error rate is: %f" % (errorCount/float(mTest))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: