HDU2966 In case of failure(浅谈k-d tree)

嘟嘟嘟


题意:给定\(n\)个二维平面上的点\((x_i, y_i)\),求离每一个点最近的点得距离的平方。(\(n \leqslant 1e5\)


这就是k-d tree入门题了。
k-d tree这东西跟平衡树有点像,但却不一样,而且查询的最坏复杂度是\(O(\sqrt{n})\)的。
首先推荐两篇博客:
K-D tree 数据结构
k-d tree入门
在众多博客之中算是非常好的。


先说一下建树。
建树可以理解为多维平衡树(但愿能这么叫),所以是平衡树结构。不过没有旋转等一系列复杂的操作。
然后对于每一层,我们按其中一个维度划分,并取这个维度的中位数上的点作为这个节点的值\(val\),然后把在这一维小于\(val\)的递归到左子树处理,大于的递归到右子树处理。
那么可以想象出来,如果是二维的话,每一个节点都是一个矩形。
至于我们选哪一个维度,一种比较“平衡”的做法是算出当前区间所有点在每一维上的方差\(s ^ 2\),然后选方差最大的那一维进行划分。
不过实际上很少有题会这么卡,所以轮流换一维建树就行。比如当前层是按第\(x\)维建树,总共有\(d\)维,那么他的下一层按\((x + 1) \ \ mod \ \ d\)划分就行。
优化一下,用nth_element代替sort找中位数,建树复杂度就是\(O(n log n)\)的。

In void pushup(int now)		//pushup可能写的麻烦了 
{
	for(int i = 0; i < 2; ++i)
	{
		if(t[now].ch[0])
		{
			t[now].Min[i] = min(t[now].Min[i], t[t[now].ch[0]].Min[i]);
			t[now].Max[i] = max(t[now].Max[i], t[t[now].ch[0]].Max[i]);
		}
		if(t[now].ch[1])
		{
			t[now].Min[i] = min(t[now].Min[i], t[t[now].ch[1]].Min[i]);
			t[now].Max[i] = max(t[now].Max[i], t[t[now].ch[1]].Max[i]);
		}
	}
}
In void build(int& now, int d, int L, int R)
{
	if(L > R) return;
	int mid = (L + R) >> 1;
	dim = d;
	nth_element(a + L, a + mid, a + R + 1);
	t[now = ++tcnt] = a[mid]; t[now].id = a[mid].id;
	t[now].Min[0] = t[now].Max[0] = t[now].d[0];
	t[now].Min[1] = t[now].Max[1] = t[now].d[1];
	t[now].ch[0] = t[now].ch[1] = 0;
	build(t[now].ch[0], d ^ 1, L, mid - 1);
	build(t[now].ch[1], d ^ 1, mid + 1, R);
	pushup(now);
}

建树看起来很优美,但是查询就不是了。 查询可以说是暴力+A*剪枝。 啊对了,我的查询和上面博客里的不一样,是asdfz的tyx大佬给我讲的。 一下以二维k-d tree为例: 暴力自然不必说,遍历整棵树即可。 关键是剪枝。 假如当前答案是$ans$(开成全局变量),那么如果要查询的点和矩形的最近距离都比$ans$大的话,自然不进入这个矩形(子树)查找。 ![](https://img2018.cnblogs.com/blog/1284378/201901/1284378-20190115090226183-2051604632.png) 也就是上面的蓝色线。(很显然对吧) 这里在补充一下,蓝色线的垂足可能没有点,但因为这是个估价函数,即最优的情况还比答案劣,我们就不进入这个子树。 那么怎么求这个距离呢? 其实就是把矩形确定下来。 那么对于每一维,我们维护一个min和max就行啦。 然后观察这个距离,其实就是查询点哪一维不在矩形里面,就把这一维的贡献算上。 最优复杂度显然是$O(log n)$,但我也不知道为啥最坏复杂度是$O(\sqrt{n})$。 ```c++ ll ans = INF; In ll dis(int now, ll* d) { ll ret = 0; for(int i = 0; i < 2; ++i) ret += (t[now].d[i] - d[i]) * (t[now].d[i] - d[i]); return ret; } In ll price(int now, ll* d) { ll ret = 0; for(int i = 0; i < 2; ++i) if(t[now].Max[i] < d[i]) ret += (d[i] - t[now].Max[i]) * (d[i] - t[now].Max[i]); else if(t[now].Min[i] > d[i]) ret += (d[i] - t[now].Min[i]) * (d[i] - t[now].Min[i]); return ret; }

In void query(int now, int id)
{
if(!now) return;
if(t[now].id ^ id) ans = min(ans, dis(now, b[id].d));
ll disL = price(t[now].ch[0], b[id].d), disR = price(t[now].ch[1], b[id].d);
if(disL < ans) query(t[now].ch[0], id);
if(disR < ans) query(t[now].ch[1], id); //千万不要写else if!!
}

代码还是很简单的。
</br>
k-d tree还可以支持查询最近的k个点,准备今天学学。
完整代码
```c++
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const ll INF = 1e18;
const db eps = 1e-8;
const int maxn = 1e5 + 5;
inline ll read()
{
	ll ans = 0;
	char ch = getchar(), last = ' ';
	while(!isdigit(ch)) last = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(last == '-') ans = -ans;
	return ans;
}
inline void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n, dim = 0;
struct Tree
{
	int ch[2], id;
	ll d[2], Min[2], Max[2]; 
	In bool operator < (const Tree& oth)const
	{
		return d[dim] < oth.d[dim];
	}
}t[maxn << 2], a[maxn], b[maxn];
int root, tcnt = 0;

In void pushup(int now)		//pushup可能写的麻烦了 
{
	for(int i = 0; i < 2; ++i)
	{
		if(t[now].ch[0])
		{
			t[now].Min[i] = min(t[now].Min[i], t[t[now].ch[0]].Min[i]);
			t[now].Max[i] = max(t[now].Max[i], t[t[now].ch[0]].Max[i]);
		}
		if(t[now].ch[1])
		{
			t[now].Min[i] = min(t[now].Min[i], t[t[now].ch[1]].Min[i]);
			t[now].Max[i] = max(t[now].Max[i], t[t[now].ch[1]].Max[i]);
		}
	}
}
In void build(int& now, int d, int L, int R)
{
	if(L > R) return;
	int mid = (L + R) >> 1;
	dim = d;
	nth_element(a + L, a + mid, a + R + 1);
	t[now = ++tcnt] = a[mid]; t[now].id = a[mid].id;
	t[now].Min[0] = t[now].Max[0] = t[now].d[0];
	t[now].Min[1] = t[now].Max[1] = t[now].d[1];
	t[now].ch[0] = t[now].ch[1] = 0;
	build(t[now].ch[0], d ^ 1, L, mid - 1);
	build(t[now].ch[1], d ^ 1, mid + 1, R);
	pushup(now);
}

ll ans = INF;
In ll dis(int now, ll* d)
{
	ll ret = 0;
	for(int i = 0; i < 2; ++i) ret += (t[now].d[i] - d[i]) * (t[now].d[i] - d[i]);
	return ret;
}
In ll price(int now, ll* d)
{
	ll ret = 0;
	for(int i = 0; i < 2; ++i) 
		if(t[now].Max[i] < d[i]) ret += (d[i] - t[now].Max[i]) * (d[i] - t[now].Max[i]);
		else if(t[now].Min[i] > d[i]) ret += (d[i] - t[now].Min[i]) * (d[i] - t[now].Min[i]);
	return ret;
}

In void query(int now, int id)
{
	if(!now) return;
	if(t[now].id ^ id) ans = min(ans, dis(now, b[id].d));
	ll disL = price(t[now].ch[0], b[id].d), disR = price(t[now].ch[1], b[id].d);
	if(disL < ans) query(t[now].ch[0], id);
	if(disR < ans) query(t[now].ch[1], id);			//千万不要写else if!! 
}

int main()
{
	int T = read();
	while(T--)
	{
		tcnt = 0;
		n = read();
		for(int i = 1; i <= n; ++i) a[i].d[0] = read(), a[i].d[1] = read(), a[i].id = i, b[i] = a[i];
		build(root, 0, 1, n);
		for(int i = 1; i <= n; ++i)
		{
			ans = INF;
			query(root, i);
			write(ans), enter;
		}
	}
	return 0;
}
posted @ 2019-01-15 09:13  mrclr  阅读(349)  评论(0编辑  收藏  举报