您的位置:首页 > 其它

机器学习实战之近邻算法kNN

2018-04-01 15:22 691 查看
kNN,k near neighbour,k近邻算法
依据的是用欧氏距离公式来计算距离,距离越近越相似。







# -*- coding: cp936 -*-
#改造成自己可以理解的代码
import os
from numpy import *

def classify0(inX,dataSet,labels,k):
    '''
    
    分类器,以dataSet为样本,分析出inX属于labels中的哪一类,
    k是dataSet中最接近的前k个'dataSet是样本数据,已处理成数组行列式形式,每一行为一个数据,就是一个向量。
    

    '''
    dataSetSize=dataSet.shape[0] #获取有多少行
    tile_inX=tile(inX,(dataSetSize,1))#创建有同样行数的行列式
    diffMat=tile_inX-dataSet#相减
    sqDiffMat=diffMat**2#平方
    sqDistance=sqDiffMat.sum(axis=1)#在行上求和
    distance=sqDistance**0.5#开方
    '欧氏公式,相减,平方,求和,开方,计算得到距离'
    sortedDistIndicies=distance.argsort() #得到排序后的距离的索引数组,只有数组类有这种方法
    #返回索引数组的目的是为了和labels列表对应起来。

    classCount={}#定义一个收集标签重复数量的字典
    for i in range(k):
        voteIlabel=labels[sortedDistIndicies[i]]#得到对应的标签
        if voteIlabel in classCount:
            classCount[voteIlabel]=classCount[voteIlabel]+1
        else:
            classCount[voteIlabel]=1
        #将标签和对应的数量存到字典中,标签作为索引,数量作为值。
        #在字典中按值对标签进行排序,返回成列表。
    sortedClassCount=sorted(classCount.keys(),key=lambda x:classCount[x],reverse=True)
    
    return sortedClassCount[0]
        
def img2vector(filename):
    '处理图像文件(txt格式),将32*32像素的图像转成1*1024的数组,这样每个图像文件变成一行数据'
    returnVect=zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

def getTrainingMat(dirName):
    '得到训练数据集'

    trainingFileList=os.listdir(dirName)
    m=len(trainingFileList)
    trainingMat=zeros((m,1024))
    hwLabels=[]
    for i in range(m):
        fileNameStr=trainingFileList[i]
        fileStr=fileNameStr.split('.')[0]
        #因为文件名是5_20.txt这样的样式,所以需要把最前的分类名5提取出来。
        classNumStr=fileStr.split('_')[0]
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector(dirName+os.sep+fileNameStr)

    return trainingMat,hwLabels

def storeMat((inputMat,inLabels)):
    '存储数据集和类标签,防止数据缺失'
    import pickle
    fw=open(r'D:\Code3\storeMatHW.txt','w')
    
    pickle.dump((inputMat,inLabels),fw)
    fw.close()

def grabMat():
    '提取数据集和类标签'
    
    import pickle
    try:
        fr=open(r'D:\Code3\storeMatHW.txt')
    except IOError:
        testHW()
        fr=open(r'D:\Code3\storeMatHW.txt')
    

    outMat,outLabels=pickle.load(fr)
    return outMat,outLabels

    
    

def testHW():
    '测试分类程序'
    testDirName=r'H:\study\python\machine learning\machinelearninginaction\Ch02\testDigits'
    trainingDirName=r'H:\study\python\machine learning\machinelearninginaction\Ch02\trainingDigits'

    trainingMat,hwLabels=getTrainingMat(trainingDirName)
    testMat,testLabels=getTrainingMat(testDirName)
    
    storeMat((trainingMat,hwLabels))#存储样本数据集合类标签。
    errorCount=0
    mTest=len(testMat)
    for i in range(mTest):

        
        classifierResult=classify0(testMat[i],trainingMat,hwLabels,3)
       
        if classifierResult != testLabels[i]:errorCount+=1

    print('\nthe total number of errors is:%s' %errorCount)
    print('\nthe total error rate is :%f' %(errorCount/float(mTest)))

def classifyHW(inFilename):
    '判断分类手写结果'
    inData=img2vector(inFilename)
    dataSet,labels=grabMat()
    classLabel=classify0(inData,dataSet,labels,3)
    print('手写的数字是: %s' %classLabel)

if __name__=='__main__':
    classifyHW(r'H:\study\python\machine learning\machinelearninginaction\Ch02\testDigits\8_20.txt')

    
    
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: