hdu 2966 In case of failure (KD-Tree)
http://acm.hdu.edu.cn/showproblem.php?pid=2966
一道KD树的题。题意是,给出n个不重合的点,求出这n个点的最邻近点的距离的平方。
什么是KE树就不介绍了,网上有许多KD的资料,做这题前先阅读材料。我的方法参考的是http://blog.csdn.net/zhjchengfeng5/article/details/7855241这个博客的代码,划分的过程直接调用STL中的nth_element,从而减少代码量。
我的做法是直接用点集数组构建线性存储的一棵KD树,然后用类似于线段树操作对点集进行划分和查找。因为题目的特殊性,于是我们可以将查找的时候,重合的(也就是距离为0的)点间的距离赋值为inf,这样子就可以直接利用查找最近点的方法找到目标点了。
代码如下:
1 #include <cstdio> 2 #include <algorithm> 3 #include <vector> 4 #include <cstring> 5 #include <iostream> 6 7 using namespace std; 8 9 typedef long long LL; 10 11 const int N = 111111; 12 struct Point { 13 LL x[3]; 14 } p[N], ori[N]; 15 int split[20], cur, dim; 16 17 bool cmp(Point a, Point b) { 18 return a.x[cur] < b.x[cur]; 19 } 20 21 #define lson l, m - 1, depth + 1 22 #define rson m + 1, r, depth + 1 23 24 void build(int l, int r, int depth) { 25 if (l >= r) return ; 26 int m = l + r >> 1; 27 cur = depth % dim; 28 nth_element(p + l, p + m, p + r + 1, cmp); 29 build(lson); 30 build(rson); 31 } 32 33 template <class T> T sqr(T x) { return x * x;} 34 const LL inf = 0x7777777777777777ll; 35 36 LL dist(Point x, Point y) { 37 LL ret = 0; 38 for (int i = 0; i < dim; i++) { 39 ret += sqr(x.x[i] - y.x[i]); 40 } 41 return ret ? ret : inf; 42 } 43 44 LL find(Point x, int l, int r, int depth) { 45 int cur = depth % dim; 46 if (l >= r) { 47 if (l == r) return dist(x, p[l]); 48 return inf; 49 } 50 int m = l + r >> 1; 51 LL ret = dist(x, p[m]), tmp; 52 if (x.x[cur] < p[m].x[cur]) { 53 tmp = find(x, lson); 54 if (tmp > sqr(x.x[cur] - p[m].x[cur])) { 55 tmp = min(tmp, find(x, rson)); 56 } 57 } else { 58 tmp = find(x, rson); 59 if (tmp > sqr(x.x[cur] - p[m].x[cur])) { 60 tmp = min(tmp, find(x, lson)); 61 } 62 } 63 return min(ret, tmp); 64 } 65 66 int main() { 67 // freopen("in", "r", stdin); 68 int n, T; 69 scanf("%d", &T); 70 while (T-- && scanf("%d", &n)) { 71 dim = 2; 72 for (int i = 0; i < n; i++) { 73 for (int j = 0; j < 2; j++) { 74 scanf("%I64d", &ori[i].x[j]); 75 } 76 p[i] = ori[i]; 77 } 78 build(0, n - 1, 0); 79 for (int i = 0; i < n; i++) { 80 printf("%I64d\n", find(ori[i], 0, n - 1, 0)); 81 } 82 } 83 return 0; 84 }
——written by Lyon