您的位置:首页 > 编程语言 > Python开发

【十四】机器学习之路——K-近邻算法实战

2017-11-23 23:19 148 查看
使用k-近邻算法识别手写数字

  一个星期没有更新博客了,最近在看K-近邻算法和决策树,学习《机器学习实战》K-近邻算法里的实战问题代码时遇到了些问题,经过几天的硬啃,终于完成了代码。话不多说,下面一起看一下如何用K-近邻算法实现识别手写数字。[例子与代码摘自《机器学习实战》]

  简单起见,该算法只能识别0~9的数字。这里识别算法首先,咱们将数字的图像使用图形处理软件,处理成相同大小:宽高均为32像素的黑白图像。图像存储为文本格式。如下图所示:







  现在咱们手头有训练集TrainingDigits大约2000个例子,0~9中每个数字大约200个左右的例子,测试集TestDigits大约有900个左右的例子,0~9每个数字大约100个;[数据取自《机器学习实战 第2章 k-近邻算法》],每个数据命名格式如下图所示:(后面代码里会根据这个命名格式来读取相应的数字)



  OK,数据准备好了,现在可以大干一场了,我们怎么利用K近邻算法来实现数字识别呢?还记得K-近邻算法的思路吗?如果忘记的同学可以参考上一篇博客机器学习之路——k-近邻算法(KNN)。在这个数字识别问题里同样,我们的处理思路如下:

1. 首先计算测试集里的数据与训练集里数据的距离差。

2. 计算好测试点与训练集里样本点的距离后,将结果从小到大进行排序;

3. 选取距离最近的k个点,确定这k个点数据所在分类的出现频率;

4. 选择频率最高的分类作为预测数据的分类输出;

其实以上四个步骤咱们上一篇博客里已经定义了函数实现了:

def classify0(inX, dataSet, labels, k):


  咱们在这个实战问题里其实重点就是怎么把这些文本里的数据变成classify0()函数可以处理的数据。

  之前classify0()函数里计算距离的思路是将测试数据一个1*2的矩阵数组(x,y)利用tile扩展成m*2的数组(这里m为训练集数据总数),然后同训练集进行相减求平方再开根号得到距离,同样这里我们首先先将每个测试集样本数据先转化为一个1*n的矩阵数组(n代表特征值个数)。先看这段代码如何实现:

#定义数据处理的函数,将训练集转化为1*1024的矩阵
def img2vector(filename):
returnVect = zeros((1,1024))#先定义一个空矩阵数组,下面里用for循环将测试样本读入该数组中
fr = open(filename)
for i in range (32):#因为每个数字样本在文本里是32*32的矩阵数组,一行一行来赋值,所以需要两个for循环嵌套
lineStr = fr.readline()#readline()依次读取每一行,readlines()是输出文件共多少行,注意区分
for j in range(32):#读取第一行后,将第一行里的32个数据赋值给returnVect前32个数,依次类推,最终将32*32=1024个数据全部赋值完毕
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect#返回最终的数据


  完成了测试集数据处理后,现在就要处理训练集数据了,思路相同,将所有训练集数据放入一个m*1024的矩阵数组里,m为训练集样本个数,然后对应的将每个样本的分类Labels放入一个1*m的矩阵数组里,【如果这点不太懂的话可以参考我上一篇KNN算法介绍的博客,链接在上面】完成训练集的处理后,就可以利用classify0()函数对测试集数据进行分类了。好了,一起看下这段代码怎么实现。代码有点长,我会一句一句的注释,让大家更容易理解,涉及到相关的python内置函数后面有对应的链接供学习。

#定义手写数字识别函数,并计算其错误率
def handwritingClassTest():
hwLabels = []#将训练集的数据对应的Label即数字用一个list容器存储,classify0()函数输入要用到
trainingFileList = listdir('这里填训练集所在文件夹地址') #输出文件夹里所有训练集的文件名称与后缀,用于读取训练集里每个文件对应的数字即Label
m = len(trainingFileList)#计算下训练集共有多少组数据
trainingMat = zeros((m,1024))#构造一个m*1024的矩阵数组存储训练集里的所有数据,每一行是一组数据
for i in range(m):#通过for循环将m组数据对应的label赋值到hwLabels中去
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('这里填训练集所在文件夹地址%s' % fileNameStr)#将训练集里数据赋值到trainingMat
testFileList = listdir('这里填测试集所在文件夹地址')
errorCount = 0.0#识别函数识别结果是错误的个数
mTest = len(testFileList)#计算测试集数据个数
#以下代码将mTest个测试集里的数据依次进行识别,并输出识别的结果
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])#读取测试集数据对应的真实数字是多少即测试集数据的label
vectorUnderTest = img2vector('这里填测试集所在文件夹地址%s' % fileNameStr)#将测试集里数据进行img2vector转换
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)#利用分类函数来进行手写数据识别,并输出识别的结果,和对应的实际结果
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount/float(mTest))#计算错误率并输出


最终代码运行的结果如下所示,由于用了两个for循环嵌套,导致运行的时间有点久。测试集共946组数据,识别错误11个,错误率1.1628%。



  以上介绍的就是利用k-近邻算法实现的手写数字识别,代码里涉及到的一些具体函数如下,如有不懂的同学可以点进链接进行学习。

read()、readline()、readlines()函数区别

listdir()函数用法

split()函数用法

  好了,今天就讲到这里,欢迎大家多多交流,如果这篇博客对你有帮助,请动动手指帮我点个赞,谢谢!
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  机器学习 python