您的位置:首页 > 其它

K近邻法及手写数字识别系统(二)

2017-09-25 19:28 288 查看
         在K近邻法及约会网站预测系统(一),简单的介绍了K近邻法的定义和三要素,并使用K近邻法实现了简单的约会网站预测系统。本文在此基础上,构建识别数字0~9的手写数字识别系统。

        手写数字识别系统的训练集中每个数字大约有200个样本,一共约2000个数据,测试数据集包括大概900个数据。其中每个数据为32*32的二进制图像矩阵,如下图表示数字0。为了复用前面网站约会系统的分类函数,需要将矩阵转换为1*1024的向量。通过img2vector()函数创建一个1*1024的Numpy数组,循环读取给定文件的前32行,并将每行的前32个字符存入数组,返回数组。



def img2vector(filename):          #将32*32的二进制矩阵转换为1*1024的向量
returnVect = zeros((1,1024))   #构建1*1024数组
fr = open(filename)
for i in range(32):            #读取前32行
lineStr = fr.readline()
for j in range(32):        #读取前32个字符
returnVect[0,32*i+j] = int(lineStr[j])    #存入数组
return returnVect


def handwriterClassTest():            #手写数字识别系统
hwLabels = []                     #存储数据类别
trainingFileList = listdir('trainingDigits')    #给定目录下的所有文件名,需要从os模块导入listdir函数,"trainingDigits":训练数据集的文件夹
m = len(trainingFileList)                       #训练数据集长度
trainingMat = zeros((m,1024))                   #训练数据集的矩阵
for i in range(m):                              #遍历所有数据
fileNameStr = trainingFileList[i]           #获得当前文件,文件名为0_0.txt  第一个0代表数字类别,第二个0代表当前类别的序号
fileStr = fileNameStr.split('.')[0]         #截取文件名  如:0_0.txt ->  0_0
classNumStr = int(fileStr.split('_')[0])    #从文件名获取类别
hwLabels.append(classNumStr)                #将当前类别加入数组
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)  #将当前文件中的32*32矩阵转换成1*1024向量
testFileList = listdir('testDigits')            #测试数据集   "testDigits"为测试数据集的文件夹
errorCount = 0.0                                #记录错误数量
mTest = len(testFileList)                       #测试数据集数量
for i in range(mTest):                          #循环遍历测试数据
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)    #这四行代码和上面的训练数据集的处理方法一致
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)    #使用上一篇中的分类函数进行分类
print('the classify came back with %d,the real answer is:%d' % (classifierResult,classNumStr))   #打印每条数据的实际类别和预测类别
if (classifierResult != classNumStr):       #如果不相等,errorCount加1
errorCount += 1.0
print('the total number of errors is :%d' % errorCount)            #打印错误数
print('the total error rate is:%f' % (errorCount/float(mTest)))    #打印错误率



          可以看出,错误数为11个,错误率约为0.01。

end
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: