您的位置:首页 > 编程语言 > Python开发

K-nn手写数字识别--Python版

2017-12-13 13:54 399 查看
模式识别的实验作业,弄了一个晚上终于在第二天中午弄明白了!

简单来说,k-nn就是通过计算训练集和 一个测试数据之间的欧式距离,然后将计算结果按照从小到大来排序,找出最小的k个数据,分析k个数据中哪种情况出现的频率最多,那么这个测试数据就属于这一类

思路

读入数据,假设100个训练数据,将训练数据转换为100*1024的二维数组,然后循环读入测试数据,计算测试数据和100个训练数据间的欧式距离:



x1-xn为单个训练数据的所有元素,y1-yn为测试数据的所有元素

这样就得到一个数组,包含所有训练数据和测试数据的欧式距离,将距离从小到大进行排序。

3. 结果

找出k个最近的距离,看哪个数字出现的频率最多,那么这个测试数据大概率为这个数字

#解压文件
def JY():
path="/Users/fanjialiang2401/PycharmProjects/模式识别/digits.zip"
newpath="/Users/fanjialiang2401/PycharmProjects/模式识别/"
f=zipfile.ZipFile(path,'r')
for  file in f.namelist():
f.extract(file,newpath)
print("success!")
#     将32*32矩阵转换为一个长为1024的一位数字
def toVerctor(filename):
returnVect=np.zeros((1,1024))
fr=open(filename)
for i in range(32):
linestr=fr.readline();
for j in range(32):
returnVect[0,32*i+j]=int(linestr[j])
return returnVect;
# 测试 trainlist为训练集所有数据,testdata为测试数据 classLable为
def Classfiy(Trainlist,testdata,classLable,k):
listSize=len(Trainlist)
diffs=[]
for i in range(listSize):
traindata=Trainlist[i];
diffvalue=np.sum(np.square(traindata-testdata))
diff=np.sqrt(diffvalue)
diffs.append(diff)
sortIndex=np.argsort(diffs)
#sortIndex  argsort对所有元素进行排序,返回的是序号值
num=[]
for i in range(10):
num.append(0)
for i in range(k):
num[int(classLable[sortIndex[i]])]+=1;
#    找出出现频率最多的数
s=np.argsort(num)
return s[9]

#读取并且处理文件 相当于main方法 在这里调用其他方法
def Read():
hwlable=[]
# 将读入的数据32*32转换为1024*length的数组
Trainlist=os.listdir('trainingDigits')
length=len(Trainlist)
trainMat=np.zeros((length,1024))

97ab
for i in range (length):
# 读取文件名
filename=Trainlist[i]
filestr=filename.split(".")[0]
#通过字符串分割,得到数字
classNum=filestr.split('_')[0]
hwlable.append(classNum)
trainMat[i:]=toVerctor('trainingDigits/%s'%filename)
# 测试集
# 测试文件 循环比较
testFileList=os.listdir('testDigits')
errorCount=0;
TestLength=len(testFileList)
for i in range(TestLength):
filenamestr=testFileList[i]
filestr=filenamestr.split(".")[0]
classStr=filestr.split("_")[0]
# 测试向量
testVector=toVerctor('testDigits/%s'%filenamestr)
lable=Classfiy(trainMat,testVector,hwlable,5)
if lable!=int(classStr):
errorCount+=1
print('false'+str(lable)+":"+classStr)
print("正确个数:"+str(TestLength-errorCount))
print("正确率:"+str((TestLength-errorCount)/TestLength))


结果:



看的出正确率还挺高的
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: