KNN python实践
2016-03-23 18:55
411 查看
本文实现了一个KNN算法,准备用作词频统计改进版本之中,这篇博文是从我另一个刚开的博客中copy过来的。
KNN算法是一个简单的分类算法,它的动机特别简单:与一个样本点距离近的其他样本点绝大部分属于什么类别,这个样本就属于什么类别,算法的主要步骤如下:
1.计算新样本点与已知类别数据集中样本点的距离。
2.取前K个距离最近的(最相似的)点。
3.统计这K个点所在类别出现的频率。
4.选择出现频率最高的点作为新样本点的类别。
KNN算法的优点在于一般精度高,对于异常的噪音数据不敏感。KNN一个明显的问题是当属于某个类别c的数据点在已知类别数据集中大量存在时,一个待预测的样本点的前K个最近的点总是存在很多类别c的点,解决这个问题的方法是计算类别的频率时,按照距离进行加权,使得离得近的点比离的远一些点更能影响类别频率排序的结果。
KNN算法中K值的选定非常影响最后的结果,通常可以使用交叉检验来选取合适的k。下面是仿照sikit-learn的KNeighborsClassifier的调用方式写的KNN:
测试代码如下所示
这里使用machine learning in action中的提供的dating data,将90%的数据用作训练数据集,10%的数据用作测试集,选取k=50算法得到的错误率为0.08。
下面我们来看一下如何使用scikit-learn提供的KNN实现。
scikit-learn中主要提供了2种KNN,KNeighborsClassifier和RadiusNeighborsClassifier。前者使用指定的前K个近邻来预测新样本点的类别,后者则是根据一个指定的半径,使用半径内所有的点来预测。创建一个KNN分类器时有这些重要的参数:
n_neighbors/radius: 使用近邻的个数K或半径
algorithm: 实现KNN的具体算法,如kd树等
metric: 距离的计算方法,默认为'minkowski'表示minkowski距离
p: minkowski距离中的参数p,p=1表示manhattan distance(l1范数),p=2表示euclidean_distance (l2范数)
这里只列出了几个常用的参数,具体的可以参考链接。使用的方法和上面的测试代码类似,只需要将classifier替换成scikit-learn的实现就可以了。
KNN算法是一个简单的分类算法,它的动机特别简单:与一个样本点距离近的其他样本点绝大部分属于什么类别,这个样本就属于什么类别,算法的主要步骤如下:
1.计算新样本点与已知类别数据集中样本点的距离。
2.取前K个距离最近的(最相似的)点。
3.统计这K个点所在类别出现的频率。
4.选择出现频率最高的点作为新样本点的类别。
KNN算法的优点在于一般精度高,对于异常的噪音数据不敏感。KNN一个明显的问题是当属于某个类别c的数据点在已知类别数据集中大量存在时,一个待预测的样本点的前K个最近的点总是存在很多类别c的点,解决这个问题的方法是计算类别的频率时,按照距离进行加权,使得离得近的点比离的远一些点更能影响类别频率排序的结果。
KNN算法中K值的选定非常影响最后的结果,通常可以使用交叉检验来选取合适的k。下面是仿照sikit-learn的KNeighborsClassifier的调用方式写的KNN:
class KNN_Classifier: def __init__(self, k): self.k = k self.train_data = None self.train_labels = None def fit(self, train_data, train_labels): self.train_data = normalize(train_data) self.train_labels = train_labels def predict(self, test_data): if (self.train_data is None) | (self.train_labels is None): print 'fit train data first!' pre_labels = [] train_data_size = len(self.train_labels) # for every data point in test set for x in normalize(test_data): # calculate distance sq_diff_mat = (np.tile(x, (train_data_size, 1)) - self.train_data) ** 2 distances = np.sum(sq_diff_mat, axis=1) ** .5 # get lowest k distances sorted_dis_indices = distances.argsort()[0: self.k] # count the times class occur class_counts = {} for idx in sorted_dis_indices: label = labels[idx] class_counts[label] = class_counts.get(label, 0) + 1 # sort class_count dict sorted_class = sorted(class_counts.items(), key=lambda d: d[1], reverse=True) # add max voted class to pre_labels pre_labels.append(sorted_class[0][0]) return pre_labels
测试代码如下所示
# load data data, labels = load_dating_data() # split data into train set and test set split_pos = int(len(labels) * 0.9) train_data = normalize(data[0: split_pos]) train_labels = labels[0: split_pos] test_data = normalize(data[split_pos: len(labels)]) test_labels = labels[split_pos: len(labels)] # init classifier classifier = KNN_Classifier(50) # fit classifier classifier.fit(train_data, train_labels) # predict the class of test data and count error points error_num = (test_labels != classifier.predict(test_data)).sum() # calculate error rate and print print 'error rate is %f' % (error_num * 1.0 / len(test_labels))
这里使用machine learning in action中的提供的dating data,将90%的数据用作训练数据集,10%的数据用作测试集,选取k=50算法得到的错误率为0.08。
下面我们来看一下如何使用scikit-learn提供的KNN实现。
scikit-learn中主要提供了2种KNN,KNeighborsClassifier和RadiusNeighborsClassifier。前者使用指定的前K个近邻来预测新样本点的类别,后者则是根据一个指定的半径,使用半径内所有的点来预测。创建一个KNN分类器时有这些重要的参数:
n_neighbors/radius: 使用近邻的个数K或半径
algorithm: 实现KNN的具体算法,如kd树等
metric: 距离的计算方法,默认为'minkowski'表示minkowski距离
p: minkowski距离中的参数p,p=1表示manhattan distance(l1范数),p=2表示euclidean_distance (l2范数)
这里只列出了几个常用的参数,具体的可以参考链接。使用的方法和上面的测试代码类似,只需要将classifier替换成scikit-learn的实现就可以了。
相关文章推荐
- 在Python里安装Jieba中文分词组件
- Pythonj~module
- think python学习心得-(3)条件和递归
- Python文档生成工具pydoc使用介绍
- 初学者必知的Python中优雅的用法
- python编辑器pydev安装
- python+selenium环境搭建
- [python]int函数带参用法
- python转成window可执行.exe文件
- Python爬取鬼吹灯2(周建龙)(PyV8解析js)
- Python爬虫利器二之Beautiful Soup的用法
- A Byte of Python (1)安装和运行
- Selenium+WebDriver+Python 定时控制任务
- Python os.path
- python paramiko模块使用介绍
- numpy、scipy、matplotlib安装与配置
- python之psutil模块获取系统信息
- python类、对象、方法、属性之类与对象笔记
- zookeeper python接口
- python 赋值、深浅拷贝、作用域