机器学习-KNN 算法
2015-12-10 12:56
246 查看
K-Nearest Neighbour
一、主要目的在样本空间中,找到与待估计的样本最临近的K个邻居,用这几个邻居的类别来估计待测样本的类别
二、适用性
样本容量比较大的类域的自动分类,而样本容量较小的类域则容易误分。尤其适用于样本分类边界不规则的情况
三、不足
1、当样本不平衡时,比如一个类的样本容量很大,其他类的样本容量很小,输入一个样本的时候,K个临近值中大多数都是大样本容量的那个类,这时可能就会导致分类错误。改进方法是对K临近点进行加权,也就是距离近的点的权值大,距离远的点权值小。
2、计算量较大,每个待分类的样本都要计算它到全部点的距离,根据距离排序才能求得K个临近点,改进方法是:先对已知样本点进行剪辑,事先去除对分类作用不大的样本。
四、算法步骤:
1)、计算已知类别数据集合汇总的点与当前点的距离
2)、按照距离递增次序排序
3)、选取与当前点距离最近的K个点
4)、确定距离最近的前K个点所在类别的出现频率
5)、返回距离最近的前K个点中频率最高的类别作为当前点的预测分类
五、matlab 代码实现
注意:不同的K取值,会影响分类的准确率。
六、数据归一化
newData = (oldData-minValue)/(maxValue-minValue)
七、python 代码实现
调试方式:
八、进阶版
knn算法简单有效,但没有优化的暴力算法效率容易受到瓶颈。如果样本个数为N,特征维度为D,则复杂度以O(N*D)增长。
解决办法:把训练数据构建成K-D tree(k-dimensional tree).搜索速度高达O(D*log(N))。不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))
测试代码(利用sklearn 库)
# -*- coding: utf-8 -*-
import numpy as np
from sklearn import neighbors
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt
''''' 数据读入 '''
data = []
labels = []
with open("DATA-KNN.txt") as ifile:
for line in ifile:
tokens = line.strip().split(' ')
data.append([float(tk) for tk in tokens[:-1]])
labels.append(tokens[-1])
x = np.array(data)
labels = np.array(labels)
y = np.zeros(labels.shape)
''''' 标签转换为0/1 '''
y[labels=='fat']=1
''''' 拆分训练数据与测试数据 '''
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
''''' 创建网格以方便绘制 '''
h = .01
x_min, x_max = x[:, 0].min() - 0.1, x[:, 0].max() + 0.1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
''''' 训练KNN分类器 '''
clf = neighbors.KNeighborsClassifier(algorithm='kd_tree')
clf.fit(x_train, y_train)
'''''测试结果的打印'''
answer = clf.predict(x)
print(x)
print(answer)
print(y)
print(np.mean( answer == y))
'''''准确率与召回率'''
precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train))
answer = clf.predict_proba(x)[:,1]
print(classification_report(y, answer, target_names = ['thin', 'fat']))
''''' 将整个测试空间的分类结果用不同颜色区分开'''
answer = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,1]
z = answer.reshape(xx.shape)
plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
''''' 绘制训练样本 '''
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=plt.cm.Paired)
plt.xlabel(u'身高')
plt.ylabel(u'体重')
plt.show()
KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明:
1、KNeighborsClassifier可以设置3种算法:‘brute’,‘kd_tree’,‘ball_tree’。如果不知道用哪个好,设置‘auto’让KNeighborsClassifier自己根据输入去决定。
2、注意统计准确率时,分类器的score返回的是计算正确的比例,而不是R2。R2一般应用于回归问题。
3、本例先根据样本中身高体重的最大最小值,生成了一个密集网格(步长h=0.01),然后将网格中的每一个点都当成测试样本去测试,最后使用contourf函数,使用不同的颜色标注出了胖、廋两类。
容易看到,本例的分类边界,属于相对复杂,但却又与距离呈现明显规则的锯齿形。
这种边界线性函数是难以处理的。而KNN算法处理此类边界问题具有天生的优势,这个数据集达到准确率=0.94算是很优秀的结果了。
相关文章推荐
- 3K工资与8K工资的差距,不仅仅是钱!
- LeetCode - 11. Container With Most Water
- Centos配置80 端口转发
- jQuery-Ajax的一点小经验
- Install Shiny Server in Ubuntu 14.04.1
- 基于第三方WheelView 实现的一个时间选择器
- 笔记本外接显示器鼠标从左边进入
- Install R & RStudio in Ubuntu
- 博客编写客户端分享
- Xadmin 常用插件
- sql基础执行顺序
- swift学习1 基本数据类型
- Struts2框架的搭建以及架构总结
- Quartz入门
- LNMP 常见502 Bad Gateway问题汇总
- 简单的随机生成4个数字验证码的实现
- 算法练习 - 字符串的全排列(字典序排列)
- 安卓theme的设置问题
- Windows无法安装到这个磁盘。请确保在计算机的BIOS菜单中启用了磁盘控制器
- js 在myeclipse中报错