python kd树 搜索 代码
kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,可以运用在k近邻法中,实现快速k近邻搜索。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,依次选择坐标轴对空间进行切分,选择训练实例点在选定坐标轴上的中位数为切分点。具体kd树的原理可以参考kd树的原理。
代码是参考《统计学习方法》k近邻 kd树的python实现得到
首先创建一个类,用于表示树的节点,包括:该节点的值,用于划分左右子树的切分轴,左子树,右子树
class decisionnode:
def __init__(self,value=None,col=None,rb=None,lb=None):
self.value=value
self.col=col
self.rb=rb
self.lb=lb
切分点为坐标轴上的中值,下面代码求得一个序列的中值
def median(x):
n=len(x)
x=list(x)
x_order=sorted(x)
return x_order[n//2],x.index(x_order[n//2])
然后就可以构造一颗kd树,左子树小于切分点,右子树大于切分点,data是输入的数据
def buildtree(x,j=0):
rb=[]
lb=[]
m,n=x.shape
if m==0: return None
edge,row=median(x[:,j].copy())
for i in range(m):
if x[i][j]>edge:
rb.append(i)
if x[i][j]<edge:
lb.append(i)
rb_x=x[rb,:]
lb_x=x[lb,:]
rightBranch=buildtree(rb_x,(j+1)%n)
leftBranch=buildtree(lb_x,(j+1)%n)
return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下来是树的搜索过程,可以用下图表示树的搜索过程,具体过程可以参考kd树的原理。
代码如下:
#搜索树:nearestPoint,nearestValue均为全局变量
def traveltree(node,point):
global nearestPoint,nearestValue
if node==None: return
print(node.value)
print('---')
col=node.col
if point[col]>node.value[col]:
traveltree(node.rb,point)
if point[col]<node.value[col]:
traveltree(node.lb,point)
dis=dist(node.value,point)
print(dis)
if dis<nearestValue:
nearestPoint=node
nearestValue=dis
#print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
if node.rb!=None or node.lb!=None:
if abs(point[node.col] - node.value[node.col]) < nearestValue:
if point[node.col]<node.value[node.col]:
traveltree(node.rb,point)
if point[node.col]>node.value[node.col]:
traveltree(node.lb,point)
def searchtree(tree,aim):
global nearestPoint,nearestValue
#nearestPoint=None
nearestValue=float('inf')
traveltree(tree,aim)
return nearestPoint
def dist(x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
完整代码在此处取
1 import numpy as np
2 from numpy import array
3 class decisionnode:
4 def __init__(self,value=None,col=None,rb=None,lb=None):
5 self.value=value
6 self.col=col
7 self.rb=rb
8 self.lb=lb
9
10 #读取数据并将数据转换为矩阵形式
11 def readdata(filename):
12 data=open(filename).readlines()
13 x=[]
14 for line in data:
15 line=line.strip().split('\t')
16 x_i=[]
17 for num in line:
18 num=float(num)
19 x_i.append(num)
20 x.append(x_i)
21 x=array(x)
22 return x
23
24 #求序列的中值
25 def median(x):
26 n=len(x)
27 x=list(x)
28 x_order=sorted(x)
29 return x_order[n//2],x.index(x_order[n//2])
30
31 #以j列的中值划分数据,左小右大,j=节点深度%列数
32 def buildtree(x,j=0):
33 rb=[]
34 lb=[]
35 m,n=x.shape
36 if m==0: return None
37 edge,row=median(x[:,j].copy())
38 for i in range(m):
39 if x[i][j]>edge:
40 rb.append(i)
41 if x[i][j]<edge:
42 lb.append(i)
43 rb_x=x[rb,:]
44 lb_x=x[lb,:]
45 rightBranch=buildtree(rb_x,(j+1)%n)
46 leftBranch=buildtree(lb_x,(j+1)%n)
47 return decisionnode(x[row,:],j,rightBranch,leftBranch)
48
49 #搜索树:nearestPoint,nearestValue均为全局变量
50 def traveltree(node,point):
51 global nearestPoint,nearestValue
52 if node==None: return
53 print(node.value)
54 print('---')
55 col=node.col
56 if point[col]>node.value[col]:
57 traveltree(node.rb,point)
58 if point[col]<node.value[col]:
59 traveltree(node.lb,point)
60 dis=dist(node.value,point)
61 print(dis)
62 if dis<nearestValue:
63 nearestPoint=node
64 nearestValue=dis
65 #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
66 if node.rb!=None or node.lb!=None:
67 if abs(point[node.col] - node.value[node.col]) < nearestValue:
68 if point[node.col]<node.value[node.col]:
69 traveltree(node.rb,point)
70 if point[node.col]>node.value[node.col]:
71 traveltree(node.lb,point)
72
73 def searchtree(tree,aim):
74 global nearestPoint,nearestValue
75 #nearestPoint=None
76 nearestValue=float('inf')
77 traveltree(tree,aim)
78 return nearestPoint
79
80
81 def dist(x1, x2): #欧式距离的计算
82 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5