您的位置:首页 > 编程语言 > Python开发

数据挖掘10大算法(6)-K最近邻(KNN)算法的实现(java和python版)

2015-06-09 10:26 1286 查看
数据挖掘-K最近邻(KNN)算法的实现(java和python版)

KNN算法基础思想前面文章可以参考,这里主要讲解java和python的两种简单实现,也主要是理解简单的思想。

/article/1336017.html

python版本:

这里实现一个手写识别算法,这里只简单识别0~9熟悉,在上篇文章中也展示了手写识别的应用,可以参考:机器学习与数据挖掘-logistic回归及手写识别实例的实现

输入:每个手写数字已经事先处理成32*32的二进制文本,存储为txt文件。0~9每个数字都有10个训练样本,5个测试样本。训练样本集如下图:左边是文件目录,右边是其中一个文件打开显示的结果,看着像1,这里有0~9,每个数字都有是个样本来作为训练集。



第一步:将每个txt文本转化为一个向量,即32*32的数组转化为1*1024的数组,这个1*1024的数组用机器学习的术语来说就是特征向量。

[python] view
plaincopy





<span style="font-size:14px;">def img2vector(filename):

returnVect = zeros((1,1024))

fr = open(filename)

for i in range(32):

lineStr = fr.readline()

for j in range(32):

returnVect[0,32*i+j] = int(lineStr[j])

return returnVect</span>

第二步:训练样本中有10*10个图片,可以合并成一个100*1024的矩阵,每一行对应一个图片,也就是一个txt文档。

[python] view
plaincopy





def handwritingClassTest():

hwLabels = []

trainingFileList = listdir('trainingDigits')

print trainingFileList

m = len(trainingFileList)

trainingMat = zeros((m,1024))

for i in range(m):

fileNameStr = trainingFileList[i]

fileStr = fileNameStr.split('.')[0]

classNumStr = int(fileStr.split('_')[0])

hwLabels.append(classNumStr)

#print hwLabels

#print fileNameStr

trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)

#print trainingMat[i,:]

#print len(trainingMat[i,:])

testFileList = listdir('testDigits')

errorCount = 0.0

mTest = len(testFileList)

for i in range(mTest):

fileNameStr = testFileList[i]

fileStr = fileNameStr.split('.')[0]

classNumStr = int(fileStr.split('_')[0])

vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)

classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)

if (classifierResult != classNumStr): errorCount += 1.0

print "\nthe total number of errors is: %d" % errorCount

print "\nthe total error rate is: %f" % (errorCount/float(mTest))

第三步:测试样本中有10*5个图片,同样的,对于测试图片,将其转化为1*1024的向量,然后计算它与训练样本中各个图片的“距离”(这里两个向量的距离采用欧式距离),然后对距离排序,选出较小的前k个,因为这k个样本来自训练集,是已知其代表的数字的,所以被测试图片所代表的数字就可以确定为这k个中出现次数最多的那个数字。

[python] view
plaincopy





def classify0(inX, dataSet, labels, k):

dataSetSize = dataSet.shape[0]

#tile(A,(m,n))

print dataSet

print "----------------"

print tile(inX, (dataSetSize,1))

print "----------------"

diffMat = tile(inX, (dataSetSize,1)) - dataSet

print diffMat

sqDiffMat = diffMat**2

sqDistances = sqDiffMat.sum(axis=1)

distances = sqDistances**0.5

sortedDistIndicies = distances.argsort()

classCount={}

for i in range(k):

voteIlabel = labels[sortedDistIndicies[i]]

classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1

sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)

return sortedClassCount[0][0]

全部实现代码:

[python] view
plaincopy





#-*-coding:utf-8-*-

from numpy import *

import operator

from os import listdir

def classify0(inX, dataSet, labels, k):

dataSetSize = dataSet.shape[0]

#tile(A,(m,n))

print dataSet

print "----------------"

print tile(inX, (dataSetSize,1))

print "----------------"

diffMat = tile(inX, (dataSetSize,1)) - dataSet

print diffMat

sqDiffMat = diffMat**2

sqDistances = sqDiffMat.sum(axis=1)

distances = sqDistances**0.5

sortedDistIndicies = distances.argsort()

classCount={}

for i in range(k):

voteIlabel = labels[sortedDistIndicies[i]]

classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1

sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)

return sortedClassCount[0][0]

def img2vector(filename):

returnVect = zeros((1,1024))

fr = open(filename)

for i in range(32):

lineStr = fr.readline()

for j in range(32):

returnVect[0,32*i+j] = int(lineStr[j])

return returnVect

def handwritingClassTest():

hwLabels = []

trainingFileList = listdir('trainingDigits')

print trainingFileList

m = len(trainingFileList)

trainingMat = zeros((m,1024))

for i in range(m):

fileNameStr = trainingFileList[i]

fileStr = fileNameStr.split('.')[0]

classNumStr = int(fileStr.split('_')[0])

hwLabels.append(classNumStr)

#print hwLabels

#print fileNameStr

trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)

#print trainingMat[i,:]

#print len(trainingMat[i,:])

testFileList = listdir('testDigits')

errorCount = 0.0

mTest = len(testFileList)

for i in range(mTest):

fileNameStr = testFileList[i]

fileStr = fileNameStr.split('.')[0]

classNumStr = int(fileStr.split('_')[0])

vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)

classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)

if (classifierResult != classNumStr): errorCount += 1.0

print "\nthe total number of errors is: %d" % errorCount

print "\nthe total error rate is: %f" % (errorCount/float(mTest))

handwritingClassTest()

运行结果:源码文章尾可下载



java版本

先看看训练集和测试集:

训练集:



测试集:



训练集最后一列代表分类(0或者1)

代码实现:

KNN算法主体类:

[java] view
plaincopy





package Marchinglearning.knn2;

import java.util.ArrayList;

import java.util.Comparator;

import java.util.HashMap;

import java.util.List;

import java.util.Map;

import java.util.PriorityQueue;

/**

* KNN算法主体类

*/

public class KNN {

/**

* 设置优先级队列的比较函数,距离越大,优先级越高

*/

private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {

public int compare(KNNNode o1, KNNNode o2) {

if (o1.getDistance() >= o2.getDistance()) {

return 1;

} else {

return 0;

}

}

};

/**

* 获取K个不同的随机数

* @param k 随机数的个数

* @param max 随机数最大的范围

* @return 生成的随机数数组

*/

public List<Integer> getRandKNum(int k, int max) {

List<Integer> rand = new ArrayList<Integer>(k);

for (int i = 0; i < k; i++) {

int temp = (int) (Math.random() * max);

if (!rand.contains(temp)) {

rand.add(temp);

} else {

i--;

}

}

return rand;

}

/**

* 计算测试元组与训练元组之前的距离

* @param d1 测试元组

* @param d2 训练元组

* @return 距离值

*/

public double calDistance(List<Double> d1, List<Double> d2) {

System.out.println("d1:"+d1+",d2"+d2);

double distance = 0.00;

for (int i = 0; i < d1.size(); i++) {

distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));

}

return distance;

}

/**

* 执行KNN算法,获取测试元组的类别

* @param datas 训练数据集

* @param testData 测试元组

* @param k 设定的K值

* @return 测试元组的类别

*/

public String knn(List<List<Double>> datas, List<Double> testData, int k) {

PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);

List<Integer> randNum = getRandKNum(k, datas.size());

System.out.println("randNum:"+randNum.toString());

for (int i = 0; i < k; i++) {

int index = randNum.get(i);

List<Double> currData = datas.get(index);

String c = currData.get(currData.size() - 1).toString();

System.out.println("currData:"+currData+",c:"+c+",testData"+testData);

//计算测试元组与训练元组之前的距离

KNNNode node = new KNNNode(index, calDistance(testData, currData), c);

pq.add(node);

}

for (int i = 0; i < datas.size(); i++) {

List<Double> t = datas.get(i);

System.out.println("testData:"+testData);

System.out.println("t:"+t);

double distance = calDistance(testData, t);

System.out.println("distance:"+distance);

KNNNode top = pq.peek();

if (top.getDistance() > distance) {

pq.remove();

pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));

}

}

return getMostClass(pq);

}

/**

* 获取所得到的k个最近邻元组的多数类

* @param pq 存储k个最近近邻元组的优先级队列

* @return 多数类的名称

*/

private String getMostClass(PriorityQueue<KNNNode> pq) {

Map<String, Integer> classCount = new HashMap<String, Integer>();

for (int i = 0; i < pq.size(); i++) {

KNNNode node = pq.remove();

String c = node.getC();

if (classCount.containsKey(c)) {

classCount.put(c, classCount.get(c) + 1);

} else {

classCount.put(c, 1);

}

}

int maxIndex = -1;

int maxCount = 0;

Object[] classes = classCount.keySet().toArray();

for (int i = 0; i < classes.length; i++) {

if (classCount.get(classes[i]) > maxCount) {

maxIndex = i;

maxCount = classCount.get(classes[i]);

}

}

return classes[maxIndex].toString();

}

}

KNN结点类,用来存储最近邻的k个元组相关的信息

[java] view
plaincopy





package Marchinglearning.knn2;

/**

* KNN结点类,用来存储最近邻的k个元组相关的信息

*/

public class KNNNode {

private int index; // 元组标号

private double distance; // 与测试元组的距离

private String c; // 所属类别

public KNNNode(int index, double distance, String c) {

super();

this.index = index;

this.distance = distance;

this.c = c;

}

public int getIndex() {

return index;

}

public void setIndex(int index) {

this.index = index;

}

public double getDistance() {

return distance;

}

public void setDistance(double distance) {

this.distance = distance;

}

public String getC() {

return c;

}

public void setC(String c) {

this.c = c;

}

}

KNN算法测试类

[java] view
plaincopy





package Marchinglearning.knn2;

import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.util.ArrayList;

import java.util.List;

/**

* KNN算法测试类

*/

public class TestKNN {

/**

* 从数据文件中读取数据

* @param datas 存储数据的集合对象

* @param path 数据文件的路径

*/

public void read(List<List<Double>> datas, String path){

try {

BufferedReader br = new BufferedReader(new FileReader(new File(path)));

String data = br.readLine();

List<Double> l = null;

while (data != null) {

String t[] = data.split(" ");

l = new ArrayList<Double>();

for (int i = 0; i < t.length; i++) {

l.add(Double.parseDouble(t[i]));

}

datas.add(l);

data = br.readLine();

}

} catch (Exception e) {

e.printStackTrace();

}

}

/**

* 程序执行入口

* @param args

*/

public static void main(String[] args) {

TestKNN t = new TestKNN();

String datafile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator + "datafile.data";

String testfile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator +"testfile.data";

System.out.println("datafile:"+datafile);

System.out.println("testfile:"+testfile);

try {

List<List<Double>> datas = new ArrayList<List<Double>>();

List<List<Double>> testDatas = new ArrayList<List<Double>>();

t.read(datas, datafile);

t.read(testDatas, testfile);

KNN knn = new KNN();

for (int i = 0; i < testDatas.size(); i++) {

List<Double> test = testDatas.get(i);

System.out.print("测试元组: ");

for (int j = 0; j < test.size(); j++) {

System.out.print(test.get(j) + " ");

}

System.out.print("类别为: ");

System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));

}

} catch (Exception e) {

e.printStackTrace();

}

}

}

运行结果为:



资源下载:

python版本下载

java版本下载
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: