机器学习实战--k近邻算法(续)
2016-05-17 15:00
423 查看
接续上次的k近邻算法,上一篇博文地址,这里用一个新的实例进行算法的验证。
一个手写数字识别系统,为了处理方便,书中已经将样本训练好,并转化为txt格式方便后续处理。具体格式如下:
这是0的其中一种表示方式。
我们的目标是输入一个类似的数字,系统能够识别出来即可。
————————————————————————————————————————
一、具体算法实现:
————————————————————————————————————————
二、结果分析
错误率在1%左右,说明识别准确率还是挺高的;但是因为每次输入一个样本,都要计算与所有训练样本之间的欧式距离,运算量还是挺大的,速度上稍显不足。
一个手写数字识别系统,为了处理方便,书中已经将样本训练好,并转化为txt格式方便后续处理。具体格式如下:
这是0的其中一种表示方式。
我们的目标是输入一个类似的数字,系统能够识别出来即可。
————————————————————————————————————————
一、具体算法实现:
# coding: UTF-8 import numpy as np import os import operator # 将txt格式的数字转化为1*1024的向量格式 def img2vector(filename): return_vector = np.arange(1024) with open(filename) as f: for i in range(32): line = f.readline() for j in range(32): return_vector[32 * i + j] = int(line[j]) return return_vector # 分类算法实现 def classify(input_vector, trained_mat, class_list, k=3): # 欧式距离计算 rows = trained_mat.shape[0] input_mat = np.tile(input_vector, (rows, 1)) diff_mat = input_mat - trained_mat squ_mat = diff_mat ** 2 sum_mat = squ_mat.sum(axis=1) d = sum_mat ** 0.5 # 根据距离排序,获得排序后的索引 sorted_d = d.argsort() # 创建用来统计某一类标签的字典 class_count = {} for i in xrange(k): class_label = class_list[sorted_d[i]] class_count[class_label] = class_count.get(class_label, 0) + 1 # 根据统计得到的类别数量,进行排序,返回一个包含元组的列表,[(),(),...()] sorted_class = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True) return sorted_class[0][0] # 用N个行向量构成训练矩阵 # 通过txt的命名获得分类 def vector2mat(): training_file_list = os.listdir('./trainingDigits') rows = len(training_file_list) trained_mat = np.zeros((rows, 1024)) class_list = [] for index, each_file in enumerate(training_file_list): digits, _ = each_file.split('.') class_list.append(digits.split('_')[0]) trained_mat[index, :] = img2vector('./trainingDigits/%s' % each_file) print trained_mat return trained_mat, class_list # 系统错误率测试 # 通过另一组test数据作为测试样本输入 def handwriting_test(trained_mat, class_list): test_file_list = os.listdir('./testDigits') test_num = len(test_file_list) err_count = 0.0 for each_file in test_file_list: # 去掉.txt的后缀 digits = each_file.split('.')[0] # 得到已知分类,known_label类型应该与class_list中元素类别一致 known_label = digits.split('_')[0] input_vector = img2vector('./testDigits/%s' % each_file) classify_result = classify(input_vector, trained_mat, class_list) print "Predict:%s\tReal answer:%s\n" % (classify_result, known_label) if known_label != classify_result: err_count += 1.0 print "total err num:%d" % err_count print "err rate:%.2f" % (err_count / (float(test_num))) def main(): trained_mat, class_list = vector2mat() handwriting_test(trained_mat, class_list) if __name__ == '__main__': main()
————————————————————————————————————————
二、结果分析
错误率在1%左右,说明识别准确率还是挺高的;但是因为每次输入一个样本,都要计算与所有训练样本之间的欧式距离,运算量还是挺大的,速度上稍显不足。
相关文章推荐
- Python动态类型的学习---引用的理解
- Python3写爬虫(四)多线程实现数据爬取
- 垃圾邮件过滤器 python简单实现
- 下载并遍历 names.txt 文件,输出长度最长的回文人名。
- install and upgrade scrapy
- Scrapy的架构介绍
- Centos6 编译安装Python
- 使用Python生成Excel格式的图片
- 让Python文件也可以当bat文件运行
- [Python]推算数独
- Python中zip()函数用法举例
- Python中map()函数浅析
- Python将excel导入到mysql中
- Python在CAM软件Genesis2000中的应用
- 使用Shiboken为C++和Qt库创建Python绑定
- FREEBASIC 编译可被python调用的dll函数示例
- Python 七步捉虫法