您的位置:首页 > 编程语言 > Python开发

《机器学习实战》(1)——kNN算法

2013-08-06 14:12 190 查看
学习《机器学习实战》,对python语言不怎么熟悉,决定一段程序一段程序来学习,既学习算法,也顺便学习python的基础知识。最后,我将把python代码用Java重写一遍。
第一个算法kNN(k-邻近算法)
这个算法的理论很简单,很容易理解。如果学过KMeans聚类算法,那么学这个算法会感觉更简单。
我对这个算法的过程理解如下:
第一步:把所有的训练集读入到内存中,这也是这个算法为什么会有空间复杂度高的原因了。
第二步:读入待分类的向量(如果是文本,要处理成向量的方式,VSM模型在这里起作用了)
第三步:计算待分类向量到所有训练集的距离。(既然是向量计算距离,一般用欧式距离就OK了)第四步:对距离进行从小到大排序,取前k个训练集的Label。第五步:对前K个训练集的Label进行统计。把待分类向量分到Label个数最多的那一个类别。第六步:算法结束。
学习了过程,再来学习代码的实现。算法的过程了解后,就很容易得出需要输入的参数:待分类文本,训练集,训练集标签,k值。输出的参数:待分类文本的分类标签。输入输出解决了的话,至少解决了问题的三分之一。
def classify(inX,dataSet,labels,k):
//
dataSetSize=dataSet.shape[0]
diffMat=tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5
sortedDistIndicies=distances.argsort();
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]]
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
sortedClassCount=sorted(classCount.iteritems(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
把python版的翻译成java版的:
package com.vancl.knn;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
public class KNN {
/*
* @param inX 待分类的文本
* @param dataSet 训练集
* @param labels 训练集的分类标签
* @param k值
* @return 分类器得到的分类标签
* */
public char classify(double[] inX,double[][] dataSet,char[] labels,int k){
//对应python 的dataSet.shape()[0]
int dataSetSize=dataSet.length;
//对应python 的tile(inX,(dataSetSize,1))-dataSet
//和diffMat**2 两行代码
double[][] sqDiffMat=createDiffMat(inX,dataSet,dataSetSize);
//对应python 的sqdiffMat.sum(axis=1);
//和distances=sqDistances**0.5两行代码
double[] distances=sum(sqDiffMat);

Node[] disNode=new Node[distances.length];
for(int i=0;i<distances.length;i++){
Node node=new Node(distances[i],i);
disNode[i]=node;
}
//对应python中 sortedDistaIndicies=distances.argsort(),排序得到下标
Arrays.sort(disNode,new KNNCompartor());
//选择距离最小的k个点
//对应pyhton的classCount={}
Map<Character,Integer> classCount=new HashMap<Character,Integer>();

char voteLabel;
//对应python的 for i in range(k):
for(int i=0;i<k;i++){
//对应python的voteIlabel=labels[sortedDistIndicies[i]]
voteLabel=labels[disNode[i].idx];
//对应 classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
add(voteLabel,classCount);
}
//sortedClassCount=sorted(classCount.iteritems(),
//       key=operator.itemgetter(1),reverse=True)
ArrayList<Map.Entry<Character,Integer>> l = new ArrayList<Map.Entry<Character,Integer>>(classCount.entrySet());
Collections.sort(l,new Comparator<Map.Entry<Character,Integer>>(){
@Override
public int compare(Entry<Character, Integer> o1,
Entry<Character, Integer> o2) {

return o2.getValue()-o1.getValue();
}
});
//对应 return sortedClassCount[0][0]
return l.get(0).getKey();
}

public void add(char voteLabel,Map<Character,Integer> classCount){
Integer id=classCount.get(voteLabel);
if(id==null) id=0;
classCount.put(voteLabel, id+1);
}

private double[] sum(double[][] sqDiffMat) {
int i,j;
double[] sqDistances=new double[sqDiffMat.length];
for(i=0;i<sqDiffMat.length;i++){
sqDistances[i]=0;
for(j=0;j<sqDiffMat[i].length;j++){
sqDistances[i]+=sqDiffMat[i][j];
}
sqDistances[i]=Math.sqrt(sqDistances[i]);
}
return sqDistances;
}
private double[][] createDiffMat(double[] inX, double[][] dataSet,int dataSetSize) {
double[][] diffMat=new double[dataSetSize][inX.length];
for(int i=0;i<dataSetSize;i++){
System.arraycopy(inX, 0, diffMat[i], 0, inX.length);
for(int j=0;j<inX.length;j++){
diffMat[i][j]=diffMat[i][j]-dataSet[i][j];
diffMat[i][j]=Math.pow(diffMat[i][j], 2);
}
}

return diffMat;
}
class Node{
public Node(double value, int idx) {
super();
this.value = value;
this.idx = idx;
}
double value;
int idx;
}
class KNNCompartor implements Comparator<Node>{
@Override
public int compare(Node o1, Node o2) {
return Double.compare(o1.value, o2.value);
}

}
public static void main(String[] args) {
KNN knn=new KNN();
double[] inX={1,1};
double[][] dataSet ={{1,1.1},{1,1},{0,0},{0,0.1}};
char[]labels={'A','A','B','B'};
char rs=knn.classify(inX, dataSet, labels,3 );
System.out.println(rs);
}
}
至此,明白为什么机器学习的书籍为什么大都选择python,而不选择java了。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息