您的位置:首页 > 其它

MachineLearning实战03_KNN实现手写数字识别

2020-02-01 16:18 393 查看
import numpy as np
import operator
from matplotlib import pyplot as plt
from os import listdir

#####K-近邻算法识别手写数字
def img2vector(filename):
"""将存储的32*32的二进制图像矩阵转换为1*1024向量

返回一维向量returnVect
"""
returnVect=np.zeros((1,1024))
fr=open(filename)
#按行读取
for i in range(32):
#readline()是每次读取一行
#readlines()是一次性读取整个文件
lineStr=fr.readline()
#将每一行的前32个元素依次添加到returnVect中
for j in range(32):
returnVect[0,32*i+j]=int(lineStr[j])
return returnVect

def classify0(New_data,train_data,labels,k):
"""K-近邻算法

New_data表示需要预测的数据
train_data表示训练样本集
labels表示训练样本集的标签
"""
train_dataSize=train_data.shape[0] #获取训练集行数,即样本个数
diffMat=np.tile(New_data,(train_dataSize,1))-train_data
#tile函数依次将每个输入数据重复train_dataSize次,目的是计算每个输入数据与样本数据的差值
sqdiffMat=diffMat**2
sqDistance=sqdiffMat.sum(axis=1) #按行计算和,并取消二维,即组成一维数据,每个数据代表输入数据与每个样本的距离
distances=sqDistance**0.5
sortedDistIndicies=distances.argsort() #从小到大排序,返回索引值
classCount={ }
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]] #获取距离最小的前k个样本对应的标签
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1 #对于相同标签进行累加
sortedclassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
#对字典classCount的值进行从大到小排序,key=1表示对字典的值,Reverse表示从大到小
return sortedclassCount[0][0]

def handwritingClassTest():
hwLabels=[] #训练样本类别标签
trainingFileList=listdir(r"D:\STUDY\MachineLearning\KNN\手写识别系统\答案\digits\trainingDigits")
#导入训练集,函数listdir可以列出给定目录的文件名
m=len(trainingFileList)
trainingMat=np.zeros((m,1024))
for i in range(m):
fileNameStr=trainingFileList[i]
#fileNameStr获得的是每个文件名称,例如“0_0.txt"
fileStr=fileNameStr.split('.')[0]
#去掉".txt",获得0_0,[0]表示获得分割标志前的部分,[1]表示获得分割标志后的部分
classNumStr=int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:]=img2vector(r"D:\STUDY\MachineLearning\KNN\手写识别系统\答案\digits\trainingDigits\%s"%fileNameStr)
#测试样本
testFileList=listdir(r"D:\STUDY\MachineLearning\KNN\手写识别系统\答案\digits\testDigits")
errorCount=0.0
mTest=len(testFileList)
for i in range(mTest):
fileNameStr=testFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumStr=int(fileStr.split('_')[0])#获得测试集实际结果
vectorUnderTest=img2vector(r"D:\STUDY\MachineLearning\KNN\手写识别系统\答案\digits\testDigits\%s"%fileNameStr)
classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)
print("the classifier came back with;%d,the real answer is :%d"%(classifierResult,classNumStr))
if (classNumStr!=classifierResult):
errorCount+=1.0
print("the tatal number of errors is:%d"%errorCount )
print("the total error rate is :%f"%(errorCount/float(mTest)))
handwritingClassTest()
  • 点赞
  • 收藏
  • 分享
  • 文章举报
Liar_Chen 发布了21 篇原创文章 · 获赞 0 · 访问量 480 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: