kdtree

简单demo,速度提升100倍。

跑kdtree的网上的demo,速度提升100倍,被这个惊艳到!

import time
import numpy
import scipy.spatial

def find_index_of_nearest_xy(y_array, x_array, y_point, x_point):
    distance = (y_array-y_point)**2 + (x_array-x_point)**2
    idy,idx = numpy.where(distance==distance.min())
    return idy[0],idx[0]

def do_all(y_array, x_array, points):
    store = []
    for i in range(points.shape[1]):
        store.append(find_index_of_nearest_xy(y_array,x_array,points[0,i],points[1,i]))
    return store

# Create some dummy data
y_array = numpy.random.random(10000).reshape(100,100)
x_array = numpy.random.random(10000).reshape(100,100)
points = numpy.random.random(10000).reshape(2,5000)
# Time how long it takes to run
start = time.time()
results = do_all(y_array, x_array, points)
end = time.time()
print ('Completed in: ',end-start)


######kdtree######################################################3
# Shoe-horn existing data for entry into KDTree routines
combined_x_y_arrays = numpy.dstack([y_array.ravel(),x_array.ravel()])[0]
points_list = list(points.transpose())
def do_kdtree(combined_x_y_arrays,points):
    mytree = scipy.spatial.cKDTree(combined_x_y_arrays)
    dist, indexes = mytree.query(points)
    return indexes

start = time.time()
results2 = do_kdtree(combined_x_y_arrays,points_list)
end = time.time()
print ('Completed in: ',end-start)

打印如下:

Completed in:  0.16802287101745605
Completed in:  0.0044443607330322266

Process finished with exit code 0

用坐标点构建kdtree

那么kdtree到底是能干什么,能用到我们项目中吗? kdtree是我们这么大牛让我们来实验的,他说可以用到我们这项目优化中。
然后我看了还不太好直接应用,需要重构我们现在的数据处理过程还是很麻烦的。在应用中,也是先写小demo看kdtree的用法。
我们是应用在二维坐标下计算点与点之间距离。首先需要用坐标系(x,y)来构建kdtree。

import time
import numpy as np
import scipy.spatial

x,y = np.mgrid[0:5,10:13]
tmp = x.ravel()
aa = list(zip(x.ravel(),y.ravel()))
tree = scipy.spatial.KDTree(list(zip(x.ravel(),y.ravel())))
data_tree = tree.data

import time
import numpy as np
import scipy.spatial
y_a = np.arange(0,32) * 4 + 2
x_a = np.arange(0,96) * 4 + 2
x_y = [(x, y) for x in x_a for y in y_a]
tree = scipy.spatial.KDTree(x_y)
data = tree.data


有没有更加优雅的方式生成x,y坐标点? 不太会.

pts = np.array([[0.8,0.8],[9.8,1.8]])
x, y = np.mgrid[0:12,0:5]
combined_x_y_arrays1 = np.dstack([x.ravel(),y.ravel()])[0]


y_a = np.arange(0,32) * 4 + 2
x_a = np.arange(0,32) * 4 + 1
x_y = [(x, y) for x in x_a for y in y_a]
combined_x_y_arrays2 = np.dstack([x_a.ravel(),y_a.ravel()])[0]  ##这个还需要x_a y_a个数相同,而且返回的是从x_a y_a各拿一个组成坐标

kdtree demo1:

import time
import numpy as np
import scipy.spatial
x, y = np.mgrid[0:12,0:5]
tmp = x.ravel()
aa = list(zip(x.ravel(), y.ravel())) #60个list  每个list[2]
tree = scipy.spatial.KDTree(list(zip(x.ravel(), y.ravel())))
data_tree = tree.data #np [60,2]
pts = np.array([[9.8,1.8], [2.8,0.8]]) #np [2,2]
dis, index = tree.query(pts) #dis[2]  index[2]
x_y = data_tree[52] #[2]

下图是断点图可以看到:
用shape为[60,2]坐标点构建kdtree, 然后点pts查询点与基础坐标点的关系。
kdtree返回2个值,dis是查询的点与基础坐标点的最小距离,index是查询的点与基础点索引,即查询的点与哪个基础坐标点最近。索引是data_tree的索引。
因为我看到index第一个值是52,后面我直接带进data_tree取出是哪个坐标点,是x_y[10,2]这个点。在纸上画画,9.8,1.8确实这个点与[10,2]最近,距离也确实是0.2828427!
可以啊!所以可以给个总结就是:kdtree可以查询点与基础坐标点的最近距离和最近点。这个正是我们项目中所需要的,但是需要改写我们既有工程。

kdtree query的其他参数tree.query(pts, k=5, distance_upper_bound=1)

import time
import numpy as np
import scipy.spatial
x, y = np.mgrid[0:12,0:5]
tmp = x.ravel()
aa = list(zip(x.ravel(), y.ravel()))
tree = scipy.spatial.KDTree(list(zip(x.ravel(), y.ravel())))
data_tree = tree.data
pts = np.array([[9.8,1.8], [2.8,0.8]])
dis, index = tree.query(pts, k=5, distance_upper_bound=1) #dis [2,5]  index[2,5]
x_y = data_tree[52]

可以看到:
dis, index = tree.query(pts, k=5, distance_upper_bound=1)
k表示需要查找k个最近的点,distance_upper_bound表示只查找距离在distance_upper_bound以内的点,超过这个阈值的距离不需要。
看返回的dis的shape是[2,5],表示每个点返回了5个距离,超过distance_upper_bound的赋值为了无穷大inf,index的shape也是[2,5],返回的是与该查询点最近距离的5个点索引,超过距离的是直接放的是基础点个数即最后一个不存在的点。

恩!学到这,就知道了kdtree的具体用法了,这个法子可以应用在点与点查找最小距离!

基础坐标点与查询点颠倒

在一些特殊需求场景可以颠倒基础坐标点与查询点颠倒。
因为可以看到上面kdtree返回的是查询点pts与基础坐标点距离关系。那么我需要每个基础坐标点与查询点的距离关系呢?
那么就是可以颠倒kdtree的基础点和查询点,都是灵活的,看需求怎么方便怎么来。

import time
import numpy as np
import scipy.spatial
pts = np.array([[9.8,1.8], [2.8,0.8]])
x, y = np.mgrid[0:12,0:5]
tmp = x.ravel()
aa = list(zip(x.ravel(), y.ravel()))
tree = scipy.spatial.KDTree(pts)      #tree = scipy.spatial.KDTree(list(zip(x.ravel(), y.ravel())))
data_tree = tree.data
dis, index = tree.query(list(zip(x.ravel(), y.ravel())), k=5, distance_upper_bound=1)  #dis, index = tree.query(pts, k=5, distance_upper_bound=1)
x_y = data_tree[52]

posted @ 2022-06-13 11:41  无左无右  阅读(114)  评论(0编辑  收藏  举报