kd树 求k近邻 python 代码
之前两篇随笔介绍了kd树的原理,并用python实现了kd树的构建和搜索,具体可以参考
kd树常与knn算法联系在一起,knn算法通常要搜索k近邻,而不仅仅是最近邻,下面的代码将利用kd树搜索目标点的k个近邻。
首先还是创建一个类,用于保存结点的值,左右子树,以及用于划分左右子树的切分轴
class decisionnode:
def __init__(self,value=None,col=None,rb=None,lb=None):
self.value=value
self.col=col
self.rb=rb
self.lb=lb
切分点为坐标轴上的中值,下面代码求得一个序列的中值
def median(x):
n=len(x)
x=list(x)
x_order=sorted(x)
return x_order[n//2],x.index(x_order[n//2])
然后按照左子树大于切分点,右子树小于切分点的规则构造kd树,其中data是输入的数据
#以j列的中值划分数据,左小右大,j=节点深度%列数
def buildtree(x,j=0):
rb=[]
lb=[]
m,n=x.shape
if m==0: return None
edge,row=median(x[:,j].copy())
for i in range(m):
if x[i][j]>edge:
rb.append(i)
if x[i][j]<edge:
lb.append(i)
rb_x=x[rb,:]
lb_x=x[lb,:]
rightBranch=buildtree(rb_x,(j+1)%n)
leftBranch=buildtree(lb_x,(j+1)%n)
return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程大致相同,需要创建一个字典knears,用于存储k近邻的点以及与目标点的距离(欧氏距离)
搜索的过程为:
(1)第一步还是遍历树,找到目标点所属区域对应的叶节点
(2)从叶结点依次向上回退,按照寻找最近邻点的方法回退到父节点,并判断其另一个子节点对区域内是否可能存在k近邻点,具体的,在每个结点上进行以下操作:
(a)如果字典中的成员个数不足k个,将该结点加入字典
(b)如果字典中的成员不少于k个,判断该结点与目标结点之间的距离是否不大于字典中各结点所对应距离的的最大值,如果不大于,便将其加入到字典中
(c)对于父节点来说,如果目标点与其切分轴之间的距离不大于字典中各结点所对应距离的的最大值,便需要访问该父节点的另一个子节点
(3)每当字典中增加新成员,就按距离值对字典进行降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最大值
(4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是目标点的k近邻
代码如下:
#搜索树:输出目标点的近邻点
def traveltree(node,aim):
global pointlist #存储排序后的k近邻点和对应距离
if node==None: return
col=node.col
if aim[col]>node.value[col]:
traveltree(node.rb,aim)
if aim[col]<node.value[col]:
traveltree(node.lb,aim)
dis=dist(node.value,aim)
if len(knears)<k:
knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
elif dis<=pointlist[0][1]:
knears.setdefault(tuple(node.value.tolist()),dis)
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
if node.rb!=None or node.lb!=None:
if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
if aim[node.col]<node.value[node.col]:
traveltree(node.rb,aim)
if aim[node.col]>node.value[node.col]:
traveltree(node.lb,aim)
return pointlist
完整代码在此处取
1 import numpy as np
2 from numpy import array
3 class decisionnode:
4 def __init__(self,value=None,col=None,rb=None,lb=None):
5 self.value=value
6 self.col=col
7 self.rb=rb
8 self.lb=lb
9
10 #读取数据并将数据转换为矩阵形式
11 def readdata(filename):
12 data=open(filename).readlines()
13 x=[]
14 for line in data:
15 line=line.strip().split('\t')
16 x_i=[]
17 for num in line:
18 num=float(num)
19 x_i.append(num)
20 x.append(x_i)
21 x=array(x)
22 return x
23
24 #求序列的中值
25 def median(x):
26 n=len(x)
27 x=list(x)
28 x_order=sorted(x)
29 return x_order[n//2],x.index(x_order[n//2])
30
31 #以j列的中值划分数据,左小右大,j=节点深度%列数
32 def buildtree(x,j=0):
33 rb=[]
34 lb=[]
35 m,n=x.shape
36 if m==0: return None
37 edge,row=median(x[:,j].copy())
38 for i in range(m):
39 if x[i][j]>edge:
40 rb.append(i)
41 if x[i][j]<edge:
42 lb.append(i)
43 rb_x=x[rb,:]
44 lb_x=x[lb,:]
45 rightBranch=buildtree(rb_x,(j+1)%n)
46 leftBranch=buildtree(lb_x,(j+1)%n)
47 return decisionnode(x[row,:],j,rightBranch,leftBranch)
48
49 #搜索树:输出目标点的近邻点
50 def traveltree(node,aim):
51 global pointlist #存储排序后的k近邻点和对应距离
52 if node==None: return
53 col=node.col
54 if aim[col]>node.value[col]:
55 traveltree(node.rb,aim)
56 if aim[col]<node.value[col]:
57 traveltree(node.lb,aim)
58 dis=dist(node.value,aim)
59 if len(knears)<k:
60 knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
61 pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
62 elif dis<=pointlist[0][1]:
63 knears.setdefault(tuple(node.value.tolist()),dis)
64 pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
65 if node.rb!=None or node.lb!=None:
66 if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
67 if aim[node.col]<node.value[node.col]:
68 traveltree(node.rb,aim)
69 if aim[node.col]>node.value[node.col]:
70 traveltree(node.lb,aim)
71 return pointlist
72
73 def dist(x1, x2): #欧式距离的计算
74 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
75
76 knears={}
77 k=int(input('请输入k的值'))
78 if k<2: print('k不能是1')
79 global pointlist
80 pointlist=[]
81 file=input('请输入数据文件地址')
82 data=readdata(file)
83 tree=buildtree(data)
84 tmp=input('请输入目标点')
85 tmp=tmp.split(',')
86 aim=[]
87 for num in tmp:
88 num=float(num)
89 aim.append(num)
90 aim=tuple(aim)
91 pointlist=traveltree(tree,aim)
92 for point in pointlist[-k:]:
93 print(point)