您的位置:首页 > 编程语言 > Python开发

python数据挖掘实践第一章 KNN算法,以及算法的实现

2017-04-24 10:57 1061 查看
KNN思路:

1.计算测试样本与每个训练样本的相似性(即两者距离的远近),取距离最近的前K个训练样本

2.统计这K个样本中每个分类出现的次数

3.选择出现次数最多的分类作为目标分类

算法思路:

不考虑时间复杂度

1、读取测试集,或者自动生成

2、遍历测试集

3、计算测试样本与训练集每一行的距离(采用两点间的欧氏距离公式),同时保存每一行的距离和对应的分类值(生成一个新的矩阵或者集合)

4、按距离从小到大排序(距离与对应的分类一起排序),选取前K个

5、对前K个(已按距离排好序)样本的分类进行分组计数,再按数量排序,数量最多的分类就是目标分类

下面是python 2.7的实现

from numpy import *

import operator

from os import listdir

import pylab as pl

from matplotlib import pyplot as plt

def file2Matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    In every non-blank line the first three tab-separated fields are parsed
    as float features and the last field as an integer class label.

    :param filename: path of the text file to read.
    :returns: ``(returnMat, classLabelVector)`` where ``returnMat`` is an
        ``(n, 3)`` float array and ``classLabelVector`` a list of ``n`` ints,
        ``n`` being the number of non-blank lines.
    """
    # Read the file once and close it deterministically; previously it was
    # opened twice and neither handle was ever closed.
    with open(filename) as fr:
        lines = [line.strip() for line in fr]
    # Skip blank lines (e.g. an empty line at EOF) so they neither inflate
    # the row count nor crash the float conversion below.
    lines = [line for line in lines if line]

    returnMat = zeros((len(lines), 3))  # prepare matrix to return
    classLabelVector = []  # prepare labels to return
    for index, line in enumerate(lines):
        listFromLine = line.split('\t')
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector

##

def createDataSet():
    """Build a tiny toy data set for exercising the classifier.

    :returns: ``(label, group)`` -- ``label`` is a list of four class names
        and ``group`` a ``4 x 2`` array whose rows are the matching points.
    """
    # Two points near (1, 1) are class 'A'; two near the origin are class 'B'.
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    classes = ['A'] * 2 + ['B'] * 2
    return classes, array(points)

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote of its ``k`` nearest neighbours.

    Distance is Euclidean.

    :param inX: sample to classify, one value per column of ``dataSet``.
    :param dataSet: 2-D array of training samples, one per row.
    :param labels: class label for each row of ``dataSet``.
    :param k: number of nearest rows that vote; values above the number of
        training rows are clamped to that number.
    :returns: the label with the most votes among the nearest rows.
    :raises ValueError: if ``k < 1`` or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every training row.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    sortedDistIndicies = distances.argsort()

    # Never ask for more neighbours than there are rows (avoids IndexError).
    numVoters = min(k, dataSetSize)
    if numVoters < 1:
        raise ValueError("k must be >= 1 and dataSet must not be empty")

    classCount = {}
    for a in range(numVoters):
        voteIlabel = labels[sortedDistIndicies[a]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def classify01(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote of its ``k`` nearest neighbours.

    Distance is Euclidean.

    :param inX: sample to classify, one value per column of ``dataSet``.
    :param dataSet: 2-D array of training samples, one per row.
    :param labels: class label for each row of ``dataSet``.
    :param k: number of nearest rows that vote; values above the number of
        training rows are clamped to that number.
    :returns: the label with the most votes among the nearest rows.
    :raises ValueError: if ``k < 1`` or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    # Squared differences summed per row, then square-rooted: Euclidean distance.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    sortedDistIndicies = distances.argsort()

    # Clamp k so a large k does not index past the end of the data set.
    numVoters = min(k, dataSetSize)
    if numVoters < 1:
        raise ValueError("k must be >= 1 and dataSet must not be empty")

    classCount = {}
    for i in range(numVoters):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() is portable across Python 2 and 3, unlike iteritems().
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def autoNorm(datamingDataMat):
    """Min-max scale every column of ``datamingDataMat`` into ``[0, 1]``.

    Each value becomes ``(value - column_min) / column_range``.

    :param datamingDataMat: 2-D numeric array, one sample per row.
    :returns: ``(normDataSet, ranges, minValue)`` -- the scaled array, the
        per-column divisor that was used, and the per-column minimum.
        Scale a new sample consistently with ``(x - minValue) / ranges``.
        A constant column has zero range; its divisor is reported as 1 so
        that column maps to 0 instead of NaN/inf.
    """
    minValue = datamingDataMat.min(0)
    maxValue = datamingDataMat.max(0)
    # Named 'ranges' so the builtin range() is not shadowed.
    ranges = maxValue - minValue
    ranges[ranges == 0] = 1  # avoid division by zero on constant columns

    m = datamingDataMat.shape[0]
    # Divide the *shifted* data; previously the raw data was divided,
    # silently discarding the min subtraction.
    shifted = datamingDataMat - tile(minValue, (m, 1))
    normDataSet = shifted / tile(ranges, (m, 1))
    return normDataSet, ranges, minValue

def _datingClassTest(filename, hoRatio=0.10, k=3):
    """Run a hold-out test of the KNN classifier on a dating data file.

    The whole file is min-max scaled. The first ``hoRatio`` fraction of rows
    is then classified against the remaining rows using ``k`` neighbours, and
    each prediction is printed next to the true label.

    :returns: ``(errorCount, numTestVecs)``.
    """
    datingDataMat, datingLabels = file2Matrix(filename)  # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    return errorCount, numTestVecs


if __name__ == "__main__":
    import sys

    # Defaults to the author's local copy of the book's data set; pass a
    # different path as the first command-line argument to override it.
    defaultPath = 'E:\\数据挖掘资料\\《机器学习实战》源代码\\machinelearninginaction\\Ch02\\datingTestSet2.txt'
    fileName = sys.argv[1] if len(sys.argv) > 1 else defaultPath

    errorCount, numTestVecs = _datingClassTest(fileName, hoRatio=0.10)  # hold out 10%
    # With too few rows nothing is held out; skip the rate to avoid dividing by zero.
    if numTestVecs:
        print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
    print(errorCount)

---------------------------------------------------------------------------------------

以下是java 代码的实现

public class tztKNN {

/**

* set unique coparator function,distance biggest,

*/

private static Comparator<KNNNode> comparator = new Comparator<KNNNode>() {

@Override

public int compare(KNNNode o1, KNNNode o2) {

// TODO Auto-generated method stub

if(o1.getDistance() >=o2.getDistance()){

return -1;

} else{

return 1;

}

}

};

public static void main(String[] args) {

// TODO Auto-generated method stub

//第一加载文件

String fileName="E:\\数据挖掘资料\\《机器学习实战》源代码\\datingTestSet2.txt";

List<List<Float>>list = ReadFile(fileName);

/* System.out.println(list.get(0).toString());

System.out.println(list.get(1).toString());*/

List<Float>testlist =new ArrayList();

testlist.add(40920f);

testlist.add(8.326976f);

testlist.add(1.673904f);

//第二步

//可以直接计算距离

PriorityQueue<KNNNode> result = new PriorityQueue<KNNNode>();

result=Distances(testlist,list,3);

System.out.println("---------result-------");

//在这里计算出现最多的就是目标分类

String classif_re=getMostClass(result);

System.out.println(classif_re);

}

private static String getMostClass(PriorityQueue<KNNNode> result) {

Map<String,Integer> classCount = new HashMap<>();

for(int i = 0; i<5; i++){

KNNNode node = result.poll();

String c = node.getC();

System.out.println("c:"+c);

if(classCount.containsKey(c)){

classCount.put(c, classCount.get(c)+1);

}

else{

classCount.put(c, 1);

}

}

List<Entry<String,Integer>> list =new ArrayList<Entry<String,Integer>>(classCount.entrySet());

//最后通过Collections.sort(List l, Comparator c)方法来进行排序,代码如下:

Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {

public int compare(Map.Entry<String, Integer> o1,

Map.Entry<String, Integer> o2) {

return (o2.getValue() - o1.getValue());

}

});

return list.get(0).getKey();

// TODO Auto-generated method stub

/*String re="";

Map<String, Integer> classCount = new HashMap<String, Integer>() ;

int pqsize = 5;

for (int i =0; i <pqsize; i++){

KNNNode node = result.remove();

String c = node.getC();

if (classCount.containsKey(c)){

classCount.put(c, classCount.get(c)+1);

}else{

classCount.put(c, 1);

}

}

int maxIndex =-1;

int maxCount=0;

Object[]classes = classCount.keySet().toArray();

for (int i = 0; i<classes.length; i++){

if (classCount.get(classes[i])>maxCount){

maxIndex = i;

maxCount = classCount.get(classes[i]);

}

}*/

//return (String) classes[maxIndex];

}

private static PriorityQueue<KNNNode> Distances(List<Float> testlist, List<List<Float>> list,int k) {

// TODO Auto-generated method stub

double distance = 0.00;

PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k,comparator);

// 要存储每个元组的顺序,距离,和他们的分类

for(int i = 0; i<list.size(); i++){

List<Float> t = list.get(i);

String c = t.get(t.size() - 1).toString();

for (int j= 0; j < testlist.size(); j++) {

float test0 = testlist.get(j);

float base0 = t.get(j);

distance = (test0- base0)*(test0- base0);

}

System.out.println("the:"+i+" ci:"+"c:"+c+" distace:"+distance);

KNNNode node = new KNNNode(i,distance,c);

pq.add(node);

}

return pq;

}

private static List<List<Float>> ReadFile(String fileName) {

// TODO Auto-generated method stub

List<List<Float>> datasets = new ArrayList<>();

try {

File file = new File(fileName);

if(file.exists()){

InputStreamReader reader=new InputStreamReader(new FileInputStream(file),"GBK" );

BufferedReader bufferedReader = new BufferedReader(reader);

String lineTxt = "";

while((lineTxt =bufferedReader.readLine())!=null){

//System.out.println(lineTxt);

//分离分行数据

ArrayList<Float> aa = new ArrayList<>();

String[] shuxing =lineTxt.split("
");

for (String sx :shuxing){

//System.out.println(sx);

aa.add(Float.parseFloat(sx));

}

datasets.add(aa);

}

}else{

System.out.println("找不到指定的文件");

}

} catch (Exception e) {

// TODO: handle exception

System.out.println("读取文件内容出错");

e.printStackTrace();

}

return datasets;

}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐