KNN实现手写数字识别
2019-08-15 17:12
232 查看
版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_41643663/article/details/99644916
import operator import datetime import numpy as np from numpy import * from os import listdir from skimage import io print('程序处理的图片大小,建议不要超过200*200\n') N = int(input('需要处理的图片的大小(100至200),N=')) #N = 120 # 图片大小:N*N color = 100 / 255 # 灰度阈值 #KNN算法主体 def KNN(test_data,train_data,train_label,k): #已知分类的数据集(训练集)的行数 dataSetSize = train_data.shape[0] #求所有距离:tile函数将输入点拓展成与训练集相同维数的矩阵,并计算测试样本与每一个训练样本的距离 all_distances = np.sqrt(np.sum(np.square(tile(test_data,(dataSetSize,1))-train_data),axis=1)) #按all_distances中元素进行升序排序后得到其对应索引的列表 sort_distance_index = all_distances.argsort() #选择距离最小的k个点 all_predictive_value = {} for i in range(k): #返回最小距离的训练集的索引(预测值) predictive_value = train_label[sort_distance_index[i]] print('第',i+1,'次预测值',predictive_value) all_predictive_value[predictive_value] = all_predictive_value.get(predictive_value,0)+1 #求众数:按classCount字典的第2个元素(即类别出现的次数)从大到小排序 sorted_class_count = sorted(all_predictive_value.items(), key = operator.itemgetter(1), reverse = True) return sorted_class_count[0][0] #训练集:得到训练集数据矩阵、下标签索引 def get_all_train_data(): train_label = [] train_file_list = listdir('trainlist') #获取目录内容 m = len(train_file_list) #m维向量的训练集 #get_train_data函数:得到所有训练集图像的向量矩阵 train_data = get_all_data(train_file_list,1) for i in range(m): file_name = train_file_list[i] #fileNameStr:所有训练集文件名 train_label.append(get_number_cut(file_name)) #得到训练集下标 return train_label,train_data #得到所有训练集/测试集的向量矩阵(k=1训练集传入;k=0测试集传入) def get_all_data(file_list,k): train_data = np.zeros([len(file_list), N**2]) #enumerate函数用于遍历序列中的元素以及它们的下标(i是下标,item是元素信息) for i, item in enumerate(file_list): if k == 1: #训练集:读取图片并转为灰度值(黑字体为0,白底为255) img = io.imread('./trainlist/'+ item, as_grey = True) else: #测试集:读取图片并转为灰度值(黑字体为0,白底为255) img = io.imread('./testlist/' + item, as_grey = True) #降噪处理 img[img>color] = 1 #将图片进行切割,保留有值的部分 img = get_cut_picture(img) #将图片进行拉伸,得到需求大小:N*N img = get_stretch_picture(img).reshape(N**2) #将处理后的图片信息存入矩阵 train_data[i, 0:N**2] = img #若将图片的真实值存入矩阵(需要存入图片索引,上面语句改train_data = np.zeros([len(file_list), N**2+1]) #train_data[i, N**2] = float(item[0]) return train_data #切割图象 def get_cut_picture(img): #初始化新大小 size = [] #图片的行数 length = len(img) #图片的列数 width = len(img[0,:]) #计算新大小 size.append(get_edge(img, length, 0, [-1, -1])) size.append(get_edge(img, width, 1, [-1, -1])) size = np.array(size).reshape(4) #print('图像尺寸(高低左右):',size) return img[size[0]:size[1]+1, size[2]:size[3]+1] #获取切割边缘(高低左右的索引) def get_edge(img, length, flag, size): for i in range(length): #判断是行是列 if flag == 0: #正序判断该行是否有手写数字 line1 = img[img[i,:]<color] #倒序判断该行是否有手写数字 line2 = img[img[length-1-i,:]<color] else: line1 = img[img[:,i]<color] line2 = img[img[:,length-1-i]<color] #若有手写数字,即到达边界,记录下行 if len(line1)>=1 and size[0]==-1: size[0] = i if len(line2)>=1 and size[1]==-1: size[1] = length-1-i #若上下边界都得到,则跳出 if size[0]!=-1 and size[1]!=-1: break return size #拉伸图像 def get_stretch_picture(img): newImg = np.ones(N**2).reshape(N, N) newImg1 = np.ones(N ** 2).reshape(N, N) #对每一行/列进行拉伸/压缩 #每一行拉伸/压缩的步长 step1 = len(img[0])/N #每一列拉伸/压缩的步长 step2 = len(img)/N #对每一行进行操作 for i in range(len(img)): for j in range(N): newImg[i, j] = img[i, int(np.floor(j*step1))] #对每一列进行操作 for i in range(N): for j in range(N): newImg1[j, i] = newImg[int(np.floor(j*step2)), i] return newImg1 #从文件名中分解出第一个数字(真实值) def get_number_cut(file_name): fileStr = file_name.split('.')[0] #文件名格式为:0_3.txt classNumStr = int(fileStr.split('_')[0]) return classNumStr #用字符矩阵打印图片 def get_show(test_data): for i in range(N**2): if(test_data[0,i] == 0): print ("1",end='') else: print ("0",end='') if (i+1)%N == 0 : print() def main(): t1 = datetime.datetime.now() # 计时开始 Nearest_Neighbor_number = int(input('选取最邻近的K个值(建议小于7),K=')) #训练集:get_train_data()函数得到训练集数据矩阵、下标签索引 train_label, train_data = get_all_train_data() #测试集:根据路径,获取测试集地址 test_file_list = listdir('testlist') file_name = test_file_list[0] #测试集:运用切片函数,得到测试集下标索引(真实值) test_index = get_number_cut(file_name) #测试集:得到训练集图像的向量矩阵 test_data = get_all_data(test_file_list,0) #测试集:get_show()函数:用字符矩阵打印图片 #get_show(test_data) #调用knn算法进行测试 Result = KNN(test_data, train_data, train_label, Nearest_Neighbor_number) print ("最终预测值为:",Result," 真实值:",test_index) t2 = datetime.datetime.now() print('耗 时 = ', t2 - t1) if __name__ == "__main__": main()
相关文章推荐
- 利用scikit-learn下的knn实现kaggle的手写数字识别问题
- opencv+KNN实现手写简单数字识别
- 各种机器学习方法(线性回归、支持向量机、决策树、朴素贝叶斯、KNN算法、逻辑回归)实现手写数字识别并用准确率、召回率、F1进行评估
- 机器学习入门-用KNN实现手写数字图片识别(包含自己图片转化)
- 【机器学习】Knn算法实现手写数字识别
- SVM和Knn实现手写数字识别
- 学习KNN(二)KNN算法手写数字识别的OpenCV实现
- knn算法实现的数字手写识别
- Python实现knn算法手写数字识别
- python3与机器学习实践---2、KNN实现手写数字识别
- 手写数字识别(2)——KNN实现
- 使用Knn算法实现手写数字识别系统(附带jpg转txt代码)
- 学习KNN(三)KNN+HOG实现手写数字识别
- KNN分类算法实现手写数字识别
- MachineLearning— (KNN)k Nearest Neighbor实现手写数字识别(三)
- 机器学习-kNN实现简单的手写数字识别系统
- KNN实现手写数字识别
- Matlab实现手写数字识别(PCA+KNN)
- 学习笔记——《机器学习实战》KNN算法实现 约会网站测试,手写数字识别,代码,注释,错误修改
- 用kNN实现MNIST手写数字识别