python kd树 搜索 代码

  kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,可以运用在k近邻法中,实现快速k近邻搜索。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,依次选择坐标轴对空间进行切分,选择训练实例点在选定坐标轴上的中位数为切分点。具体kd树的原理可以参考kd树的原理。

  代码是参考《统计学习方法》k近邻 kd树的python实现得到

  首先创建一个类,用于表示树的节点,包括:该节点的值,用于划分左右子树的切分轴,左子树,右子树

class decisionnode:
    def __init__(self,value=None,col=None,rb=None,lb=None):
        self.value=value
        self.col=col
        self.rb=rb
        self.lb=lb

  切分点为坐标轴上的中值,下面代码求得一个序列的中值

def median(x):
    n=len(x)
    x=list(x)
    x_order=sorted(x)
    return x_order[n//2],x.index(x_order[n//2])

然后就可以构造一颗kd树,左子树小于切分点,右子树大于切分点,data是输入的数据

def buildtree(x,j=0):
    rb=[]
    lb=[]
    m,n=x.shape
    if m==0: return None
    edge,row=median(x[:,j].copy())
    for i in range(m):
        if x[i][j]>edge: 
            rb.append(i)
        if x[i][j]<edge:
            lb.append(i)
    rb_x=x[rb,:]
    lb_x=x[lb,:]
    rightBranch=buildtree(rb_x,(j+1)%n)
    leftBranch=buildtree(lb_x,(j+1)%n)
    return decisionnode(x[row,:],j,rightBranch,leftBranch)

  接下来是树的搜索过程,可以用下图表示树的搜索过程,具体过程可以参考kd树的原理。

  

  代码如下:

#搜索树:nearestPoint,nearestValue均为全局变量
def traveltree(node,point):
    global nearestPoint,nearestValue
    if node==None: return 
    print(node.value)
    print('---')
    col=node.col
    if point[col]>node.value[col]:
        traveltree(node.rb,point)
    if point[col]<node.value[col]:
        traveltree(node.lb,point)
    dis=dist(node.value,point)
    print(dis)
    if dis<nearestValue:
        nearestPoint=node
        nearestValue=dis
        #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
    if node.rb!=None or node.lb!=None:
        if abs(point[node.col] - node.value[node.col]) < nearestValue:
            if point[node.col]<node.value[node.col]:
                traveltree(node.rb,point)
            if point[node.col]>node.value[node.col]:
                traveltree(node.lb,point)
        
def searchtree(tree,aim):
    global nearestPoint,nearestValue
    #nearestPoint=None
    nearestValue=float('inf')
    traveltree(tree,aim)
    return nearestPoint
        
    
def dist(x1, x2): #欧式距离的计算  
    return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

 完整代码在此处取

 1 import numpy as np
 2 from numpy import array
 3 class decisionnode:
 4     def __init__(self,value=None,col=None,rb=None,lb=None):
 5         self.value=value
 6         self.col=col
 7         self.rb=rb
 8         self.lb=lb
 9         
10 #读取数据并将数据转换为矩阵形式        
11 def readdata(filename):    
12     data=open(filename).readlines()
13     x=[]
14     for line in data:
15         line=line.strip().split('\t')
16         x_i=[]
17         for num in line:
18             num=float(num)
19             x_i.append(num)
20         x.append(x_i)
21     x=array(x)
22     return x
23 
24 #求序列的中值    
25 def median(x):
26     n=len(x)
27     x=list(x)
28     x_order=sorted(x)
29     return x_order[n//2],x.index(x_order[n//2])
30 
31 #以j列的中值划分数据,左小右大,j=节点深度%列数    
32 def buildtree(x,j=0):
33     rb=[]
34     lb=[]
35     m,n=x.shape
36     if m==0: return None
37     edge,row=median(x[:,j].copy())
38     for i in range(m):
39         if x[i][j]>edge: 
40             rb.append(i)
41         if x[i][j]<edge:
42             lb.append(i)
43     rb_x=x[rb,:]
44     lb_x=x[lb,:]
45     rightBranch=buildtree(rb_x,(j+1)%n)
46     leftBranch=buildtree(lb_x,(j+1)%n)
47     return decisionnode(x[row,:],j,rightBranch,leftBranch)
48 
49 #搜索树:nearestPoint,nearestValue均为全局变量
50 def traveltree(node,point):
51     global nearestPoint,nearestValue
52     if node==None: return 
53     print(node.value)
54     print('---')
55     col=node.col
56     if point[col]>node.value[col]:
57         traveltree(node.rb,point)
58     if point[col]<node.value[col]:
59         traveltree(node.lb,point)
60     dis=dist(node.value,point)
61     print(dis)
62     if dis<nearestValue:
63         nearestPoint=node
64         nearestValue=dis
65         #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
66     if node.rb!=None or node.lb!=None:
67         if abs(point[node.col] - node.value[node.col]) < nearestValue:
68             if point[node.col]<node.value[node.col]:
69                 traveltree(node.rb,point)
70             if point[node.col]>node.value[node.col]:
71                 traveltree(node.lb,point)
72         
73 def searchtree(tree,aim):
74     global nearestPoint,nearestValue
75     #nearestPoint=None
76     nearestValue=float('inf')
77     traveltree(tree,aim)
78     return nearestPoint
79         
80     
81 def dist(x1, x2): #欧式距离的计算  
82     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
kdtree

 

posted @ 2018-02-09 19:51  Alice_鹿_Bambi  阅读(3383)  评论(2编辑  收藏  举报