kdTree实践

一、kdTree 数据结构节点

  • left:  左子树
  • right:右子树
  • fea:所选轴(特征)
  • dataNode:所选轴中点的样本

二、kdTree实现主要包括两部分:

  • 1、建树  :计算轴方差,选出方差最大的轴,进行递归二分
  • 2、查询:根据当前kdTree节点轴的值与要查询节点轴的值比较,选择向左子树(或右子树)递归查询,得到两点间左子树(或右子树)的最小距离dis;根据当前kdTree节点轴的值与要查询节点轴的差值作比较,若差值较大,则说明(超球面是否与超矩形交割)要对右子树(或左子树)回溯

三、代码实现

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Sun Sep 30 12:44:51 2018
 4 
 5 @author: Administrator
 6 """
 7 import pandas as pd
 8 import numpy as np
 9 import math
10 #定义treeNode
11 class Node:
12     def __init__(self,lTree,rTree,fea,dataNode):  #fea表示选择的轴,dataNode 以该节点进行分割左右子树
13         self.left=lTree;
14         self.right=rTree;
15         self.fea=fea;
16         self.dataNode=dataNode                #标签包含在其中、
17 
18 
19 ##直接用 DataFrame 作为数据结构
20 def getInfo():
21     data=[[2,3,''],[5,4,''],[9,6,''],[4,7,''],[8,1,''],[7,2,'']];              
22     data=pd.DataFrame(data,columns=['fea1','fea2','label'])
23     return data;
24 
25 # 计算方差,选择轴 根据轴方差
26 def calSq(data):
27     sq=data.var();        
28     pos=data.columns[0];
29     val=sq[0];
30     for i in data.columns[1:-1]:   #选择方差最大的
31         if(val<sq[i]):
32             val=sq[i];
33             pos=i;
34     return pos;
35 
36  #按轴将数据拆分
37 def splitAxis(data):  
38     fea=calSq(data);
39     sortData=data.sort_values(by=fea);   #按轴排序
40     sortData=(np.array(sortData)).tolist();  #转list
41     dataNode=pd.DataFrame( [ sortData[len(sortData)//2] ],    columns=list(data.columns));        #数据节点
42     leftSet=pd.DataFrame( sortData[0:len(sortData)//2] , columns=list(data.columns) );    #左子树
43     rightSet=pd.DataFrame(sortData[len(sortData)//2+1:] , columns=list(data.columns) );                  #右子树
44     return fea,dataNode,leftSet,rightSet;
45 
46 #建树
47 def createTree(data):   #递归建树
48     if(len(data)>0):          #如果有数据
49         fea,dataNode,leftSet,rightSet=splitAxis(data)
50         treeNode=Node(None,None,fea,dataNode);
51         if(len(leftSet)>0):            #左边是否可分
52             treeNode.left=createTree(leftSet);
53         if(len(rightSet)>0):         #右边是否可分
54             treeNode.right=createTree(rightSet);
55         return treeNode;
56   
57 #递归搜索      
58 def search(tree,preNode):    #perNode 表示要查询一个样本;
59     dis=0;
60     for i in tree.dataNode.columns[:-1]:         #计算距离
61         dis=dis+( tree.dataNode[i][0]-preNode[i][0] )**2;
62     dis=math.sqrt(dis);
63     label=tree.dataNode[tree.dataNode.columns[-1]][0];  #当前节点标记
64     labelL='';
65     labelR='';
66     if(tree.left!=None and preNode[tree.fea][0] < tree.dataNode[tree.fea][0] ): #左边搜索
67         disL,labelL = search( tree.left, preNode );
68         if(disL<dis):                                                           #取距离最小的
69             dis=disL
70             label=labelL;
71         if( dis >  abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])): #超球面是否与超矩形交割 判断是否要回溯
72             disHR,labelHR=search(tree.right,preNode);                                     #回溯右子树
73             if(disHR<dis):
74                 return disHR,labelHR
75             else:
76                 return dis,label
77         
78     if(tree.right!=None and preNode[tree.fea][0] >= tree.dataNode[tree.fea][0] ): #右边搜索
79         disR,labelR=search(tree.right,preNode);
80         if(disR < dis):                                                          #取距离最小的
81             dis=disR;
82             label=labelR;
83         if( dis >  abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])):  #超球面是否与超矩形交割 判断是否要回溯
84             disHL,labelHL=search(tree.left,preNode);                        #回溯左子树
85             if(disHL<dis):
86                 return disHL,labelHL
87             else:
88                 return dis,label
89     return dis,label;
90     
91 data=getInfo();
92 root=createTree(data);
93 test=pd.DataFrame( [ [7.1,1] ], columns=list(data.columns[:-1]));
94 dis,label=search(root,test)

 

posted @ 2018-10-04 20:58  bear_ge  阅读(622)  评论(0编辑  收藏  举报