K近邻法及手写数字识别系统(二)
2017-09-25 19:28
288 查看
在K近邻法及约会网站预测系统(一),简单的介绍了K近邻法的定义和三要素,并使用K近邻法实现了简单的约会网站预测系统。本文在此基础上,构建识别数字0~9的手写数字识别系统。
手写数字识别系统的训练集中每个数字大约有200个样本,一共约2000个数据,测试数据集包括大概900个数据。其中每个数据为32*32的二进制图像矩阵,如下图表示数字0。为了复用前面网站约会系统的分类函数,需要将矩阵转换为1*1024的向量。通过img2vector()函数创建一个1*1024的Numpy数组,循环读取给定文件的前32行,并将每行的前32个字符存入数组,返回数组。
可以看出,错误数为11个,错误率约为0.01。
end
手写数字识别系统的训练集中每个数字大约有200个样本,一共约2000个数据,测试数据集包括大概900个数据。其中每个数据为32*32的二进制图像矩阵,如下图表示数字0。为了复用前面网站约会系统的分类函数,需要将矩阵转换为1*1024的向量。通过img2vector()函数创建一个1*1024的Numpy数组,循环读取给定文件的前32行,并将每行的前32个字符存入数组,返回数组。
def img2vector(filename): #将32*32的二进制矩阵转换为1*1024的向量 returnVect = zeros((1,1024)) #构建1*1024数组 fr = open(filename) for i in range(32): #读取前32行 lineStr = fr.readline() for j in range(32): #读取前32个字符 returnVect[0,32*i+j] = int(lineStr[j]) #存入数组 return returnVect
def handwriterClassTest(): #手写数字识别系统 hwLabels = [] #存储数据类别 trainingFileList = listdir('trainingDigits') #给定目录下的所有文件名,需要从os模块导入listdir函数,"trainingDigits":训练数据集的文件夹 m = len(trainingFileList) #训练数据集长度 trainingMat = zeros((m,1024)) #训练数据集的矩阵 for i in range(m): #遍历所有数据 fileNameStr = trainingFileList[i] #获得当前文件,文件名为0_0.txt 第一个0代表数字类别,第二个0代表当前类别的序号 fileStr = fileNameStr.split('.')[0] #截取文件名 如:0_0.txt -> 0_0 classNumStr = int(fileStr.split('_')[0]) #从文件名获取类别 hwLabels.append(classNumStr) #将当前类别加入数组 trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) #将当前文件中的32*32矩阵转换成1*1024向量 testFileList = listdir('testDigits') #测试数据集 "testDigits"为测试数据集的文件夹 errorCount = 0.0 #记录错误数量 mTest = len(testFileList) #测试数据集数量 for i in range(mTest): #循环遍历测试数据 fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) #这四行代码和上面的训练数据集的处理方法一致 classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3) #使用上一篇中的分类函数进行分类 print('the classify came back with %d,the real answer is:%d' % (classifierResult,classNumStr)) #打印每条数据的实际类别和预测类别 if (classifierResult != classNumStr): #如果不相等,errorCount加1 errorCount += 1.0 print('the total number of errors is :%d' % errorCount) #打印错误数 print('the total error rate is:%f' % (errorCount/float(mTest))) #打印错误率
可以看出,错误数为11个,错误率约为0.01。
end
相关文章推荐
- caffe中如何训练自己的手写数字识别系统?
- 深度学习与神经网络实战:快速构建一个基于神经网络的手写数字识别系统
- Python实现识别手写数字 简易图片存储管理系统
- [毕业设计]手写数字识别系统设计与实现
- 机器学习-kNN实现简单的手写数字识别系统
- 手写数字识别系统之数字提取
- 使用Knn算法实现手写数字识别系统(附带jpg转txt代码)
- 手写数字识别系统之图像分割
- 手写数字识别系统之倾斜矫正
- 【《机器学习实战》第2章读书笔记】手写数字识别系统剖析
- Python(TensorFlow框架)实现手写数字识别系统
- 【作业】手写数字识别系统
- 机器学习--knn手写数字识别系统
- 机器学习实战之程序清单1-kNN(手写数字识别系统)
- 手写数字识别系统编程技巧
- 机器学习实战之k-近邻算法(6)---手写数字识别系统(0-9识别)
- k-近邻算法实现手写数字识别系统
- k-近邻算法实现手写数字识别系统
- 手写数字识别系统之倾斜矫正
- Python(TensorFlow框架)实现手写数字识别系统的方法