您的位置:首页 > 编程语言 > Python开发

[置顶] 《统计学习方法》k近邻 kd树的python实现

2017-08-17 15:22 501 查看
最近在阅读统计学习方法的k近邻这一章的时候,用Python实现了kd树中的算法。
先大概说下为啥要用kd树,k近邻其实实现起来很简单,就是求下欧式距离,但是传统的k近邻需要遍历所有的样本,而kd树进行了改进,使其不用计算所有的距离。
分为两部分,一个是kd树建立,一个是kd树的搜索。来看代码。
# --*-- coding:utf-8 --*--
import numpy as np对了,别忘了,先定义一下字符集还有包。
首先我们先实现一个结点类。
class Node:
def __init__(self, data, lchild = None, rchild = None):
self.data = data
self.lchild = lchild
self.rchild = rchild
一个结点包含着结点域,左孩子,右孩子。
然后是创建kd树的代码。
def create(self, dataSet, depth): #创建kd树,返回根结点
if (len(dataSet) > 0):
m, n = np.shape(dataSet) #求出样本行,列
midIndex = m / 2 #中间数的索引位置
axis = depth % n #判断以哪个轴划分数据
sortedDataSet = self.sort(dataSet, axis) #进行排序
node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
# print sortedDataSet[midIndex]
leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
rightDataSet = sortedDataSet[midIndex+1 :]
print leftDataSet
print rightDataSet
node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
node.rchild = self.create(rightDataSet, depth+1)
return node
else:
return None书中讲到了,需要按轴(depth(深度) mod n(特征数)+1)划分中位数,然后决定插入数据到左结点,右结点。对了,注意一下为什么上面的轴的公式是depth(深度) mod n(特征数),这是因为python的数组下标是从0开始的。def sort(self, dataSet, axis): #采用冒泡排序,利用aixs作为轴进行划分
sortDataSet = dataSet[:] #由于不能破坏原样本,此处建立一个副本
m, n = np.shape(sortDataSet)
for i in range(m):
for j in range(0, m - i - 1):
if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
temp = sortDataSet[j]
sortDataSet[j] = sortDataSet[j+1]
sortDataSet[j+1] = temp
print sortDataSet
return sortDataSet创建树的时候为了找中位数,需要按轴排序,来找,这里我用了冒泡排序。
def preOrder(self, node):
if node != None:
print "tttt->%s" % node.data
self.preOrder(node.lchild)
self.preOrder(node.rchild)
当然我选择了先序遍历来简单检查下树有没有问题。def search(self, tree, x):
self.nearestPoint = None #保存最近的点
self.nearestValue = 0 #保存最近的值
def travel(node, depth = 0): #递归搜索
if node != None: #递归终止条件
n = len(x) #特征数
axis = depth % n #计算轴
if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
travel(node.lchild, depth+1)
else:
travel(node.rchild, depth+1)

#以下是递归完毕后,往父结点方向回朔
distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
elif (self.nearestValue > distNodeAndX):
self.nearestPoint = node.data
self.nearestValue = distNodeAndX

print node.data, depth, self.nearestValue, node.data[axis], x[axis]
if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断)
if x[axis] < node.data[axis]:
travel(node.rchild, depth+1)
else:
travel(node.lchild, depth + 1)
travel(tree)
return self.nearestPoint

def dist(self, x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5搜索树的时候比较麻烦,首先先说下原理吧。
(1) 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每个结点进行以下操作:
  (a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
  (b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为x的最近邻点。
注意了,先按步骤找到叶结点,然后回朔的时候要做两件事,1、是更新最新点,2、是检查是否需要检查父结节点的另外一个结点的区域。
(1)容易实现,但是(2)的原理大概是判断跟目标点最近的值形成的一个圆是否跟父结点按轴分的那条线有交集。
说白了,就是公式:|目标值(按轴读值) - 父节点(按轴读值)| < 最近的值(圆的半径)
如果找到了的话,把另一结点重新递归一次就好了。
最后我们来运行一下。
dataSet = [[2, 3],
[5, 4],
[9, 6],
[4, 7],
[8, 1],
[7, 2]]
x = [5, 3]
kdtree = KdTree()
tree = kdtree.create(dataSet, 0)
kdtree.preOrder(tree)
print kdtree.search(tree, x)结果输出(5, 4)

# --*-- coding:utf-8 --*--
import numpy as np
class Node:
def __init__(self, data, lchild = None, rchild = None):
self.data = data
self.lchild = lchild
self.rchild = rchild

class KdTree:
def __init__(self):
self.kdTree = None

def create(self, dataSet, depth): #创建kd树,返回根结点
if (len(dataSet) > 0):
m, n = np.shape(dataSet) #求出样本行,列
midIndex = int(m / 2) #中间数的索引位置
axis = depth % n #判断以哪个轴划分数据
sortedDataSet = self.sort(dataSet, axis) #进行排序
node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
# print sortedDataSet[midIndex]
leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
rightDataSet = sortedDataSet[midIndex+1 :]
print(leftDataSet)
print(rightDataSet)
node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
node.rchild = self.create(rightDataSet, depth+1)
return node
else:
return None

def sort(self, dataSet, axis): #采用冒泡排序,利用aixs作为轴进行划分
sortDataSet = dataSet[:] #由于不能破坏原样本,此处建立一个副本
m, n = np.shape(sortDataSet)
for i in range(m):
for j in range(0, m - i - 1):
if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
temp = sortDataSet[j]
sortDataSet[j] = sortDataSet[j+1]
sortDataSet[j+1] = temp
print(sortDataSet)
return sortDataSet

def preOrder(self, node):
if node != None:
print("tttt->%s" % node.data)
self.preOrder(node.lchild)
self.preOrder(node.rchild)

# def search(self, tree, x):
# node = tree
# depth = 0
# while (node != None):
# print node.data
# n = len(x) #特征数
# axis = depth % n
# if x[axis] < node.data[axis]:
# node = node.lchild
# else:
# node = node.rchild
# depth += 1
def search(self, tree, x):
self.nearestPoint = None #保存最近的点
self.nearestValue = 0 #保存最近的值
def travel(node, depth = 0): #递归搜索
if node != None: #递归终止条件
n = len(x) #特征数
axis = depth % n #计算轴
if x[axis] < node.data[axis]: #如果数据小于结点,则往左结点找
travel(node.lchild, depth+1)
else:
travel(node.rchild, depth+1)

#以下是递归完毕后,往父结点方向回朔
distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断
if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
elif (self.nearestValue > distNodeAndX):
self.nearestPoint = node.data
self.nearestValue = distNodeAndX

print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #确定是否需要去子节点的区域去找(圆的判断)
if x[axis] < node.data[axis]:
travel(node.rchild, depth+1)
else:
travel(node.lchild, depth + 1)
travel(tree)
return self.nearestPoint

def dist(self, x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

dataSet = [[2, 3],
[5, 4],
[9, 6],
[4, 7],
[8, 1],
[7, 2]]
x = [5, 3]
kdtree = KdTree()
tree = kdtree.create(dataSet, 0)
kdtree.preOrder(tree)
print(kdtree.search(tree, x))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: