[学习笔记]poj3714Raid-平面最近点对变形——KDT入门
poj3714-Raid
http://poj.org/problem?id=3714
题意:平面有个\(n\)个黑点和\(n\)个白点,问最近的不同色点对的距离,\(n\leq 10^5\)。
题解
看上去和平面最近点对非常相似,嗯多了一个不同色的条件,我会分治!如果两个点颜色相同就返回一个无穷大,否则就正常算!合并呢?同样枚举左边的点,对右边的点枚举\(y_r\in[y_l-d,y_l+d]\)的点,容易证明右边的点一定不超过…等一等,能保证吗…?好像不行…?
我们想一想,比如这么造数据…假设左右两边间隔足够大,合并就能被卡到\(O(n^2)\),直接拆看来是不行…分治死了…
不过实际上这题数据太水…这么写交上去也能过…网上大部分题解也都是这么做的,要卡的话甚至让黑点全随机在左边,白点随机在右边也能卡掉(x)
KDT
卡算法到此结束…分治看起来是没戏了,如果顺着平面最近点对的做法,另一个比较容易想到的应该是随机化:把所有点随机旋转一个角度,然后往后找常数个异色点,然后另一个就是KD树啦!
KD树(k-dimension Tree)本质上是一个\(n\)个节点的二叉搜索树,节点上保存一个\(k\)维点的信息,也就是坐标\((x_1,x_2,\dots,x_k)\),对一个子区间\([l,r]\)递归地建树的时候,我们选取一个适当的维度\(d\)作为标准,选择区间中点\(m\)作为这层的树根,以保证左右子树尽量平衡。
接着剩下的所有\(x_{di}\leq x_{mi}\)的\(i\)的点就放在左半边(注意这里只要把他们都丢到一边去就行了,这一步类似归并排序归并的过程,可以做到线性复杂度)作为左子树,继续递归处理\([l,m-1]\),右子树同理。
而维度\(d\)的选取一般(我目前见过)的有两种,一般是为了保证后续查询操作的复杂度,一种是简单粗暴地让第\(d\)层就选\(d\mod k+1\)这个维度,另一种则是算每个维度上对应的方差,选取方差最大(点最分散)的维度来作为划分依据,这一步也同样是线性的,于是整个建树是\(O(n\log n)\)的。
接着就是这题需要的领域查询:另给一个点\(p\),查询\(p_1,\dots,p_n\)中离\(p\)最近的是哪个,操作起来有点像线段树上二分加上一点启发式搜索:
void query(int l,int r,int x){
if(l>r)return;
int mid=(l+r)>>1;
if(mid!=x)ans=min(ans,dist(tr[x],tr[mid]));
double disl,disr;disl=disr=inf;
if(tr[mid].ls)disl=f(tr[tr[mid].ls],tr[x]);
if(tr[mid].rs)disr=f(tr[tr[mid].rs],tr[x]);
if(disl<ans&&disr<ans){
if(disl<disr){
query(l,mid-1,x);
if(disr<ans)query(mid+1,r,x);
}else{
query(mid+1,r,x);
if(disl<ans)query(l,mid-1,x);
}
}else{
if(disl<ans)query(l,mid-1,x);
if(disr<ans)query(mid+1,r,x);
}
}
每个节点额外维护它和它子节点组成的这些点集对应横纵坐标的最大/最小值,把它看成一个大长方形,接着就能快速地算出点\(p\)到左右子树的这个“长方形”的最小距离是多少,因为到子树对应具体某个点的距离一定比到这个长方形的最小距离要大,所以我们把可以这个距离作为一个估价函数来进行搜索。
这样子做据说在点随机的情况下复杂度是\(O(\log n)\)的,不过最坏情况还是能到\(O(n)\),具体为什么我也不会证(x)
另外kdt还有像rang search的操作,不过这题没有涉及到,以及这次学KDT本来也就只是入个门,就暂时不讨论这些啦,有兴趣进一步了解的康康参考资料呀~
回到上面这道题,我们只要对所有白点建一个KDT,再对所有黑点进行查找就能得到答案啦,期望复杂度是\(O(n\log n)\)的!
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstdlib>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
const double inf=1e20;
const int N=2e5+5;
struct node
{
double x,y,L,R,D,U;
int ls,rs,d;
}tr[N];
int T,n;
double ans;
double dist(node p,node q)
{
return pow(p.x-q.x,2)+pow(p.y-q.y,2);
}
bool cmpx(node p,node q){return p.x<q.x;}
bool cmpy(node p,node q){return p.y<q.y;}
#define lson (tr[x].ls)
#define rson (tr[x].rs)
void maintain(int x)
{
tr[x].L=tr[x].R=tr[x].x;
tr[x].U=tr[x].D=tr[x].y;
if(lson)
tr[x].L=min(tr[x].L,tr[lson].L),tr[x].R=max(tr[x].R,tr[lson].R),
tr[x].D=min(tr[x].D,tr[lson].D),tr[x].U=max(tr[x].U,tr[lson].U);
if(rson)
tr[x].L=min(tr[x].L,tr[rson].L),tr[x].R=max(tr[x].R,tr[rson].R),
tr[x].D=min(tr[x].D,tr[rson].D),tr[x].U=max(tr[x].U,tr[rson].U);
}
int build(int l,int r)
{
if(l>r)return 0;
int mid=(l+r)>>1;
double agx=0,agy=0,vax=0,vay=0;
rep(i,l,r)agx+=tr[i].x,agy+=tr[i].y;
agx/=(double)(r-l+1);
agy/=(double)(r-l+1);
rep(i,l,r)vax+=pow(agx-tr[i].x,2),vay+=pow(agy-tr[i].y,2);
if(vax>vay)
tr[mid].d=1,nth_element(tr+l,tr+mid,tr+r+1,cmpx);
else
tr[mid].d=2,nth_element(tr+l,tr+mid,tr+r+1,cmpy);
tr[mid].ls=build(l,mid-1);
tr[mid].rs=build(mid+1,r);
maintain(mid);
return mid;
}
double f(node s,node x)
{
double res=0;
if(x.x<s.L)res+=pow(s.L-x.x,2);
if(x.x>s.R)res+=pow(s.R-x.x,2);
if(x.y<s.D)res+=pow(s.D-x.y,2);
if(x.y>s.U)res+=pow(s.U-x.y,2);
return res;
}
void query(int l,int r,int x)
{
if(l>r)return;
int mid=(l+r)>>1;
if(mid!=x)ans=min(ans,dist(tr[x],tr[mid]));
if(l==r)return;
double disl,disr;disl=disr=inf;
if(tr[mid].ls)disl=f(tr[tr[mid].ls],tr[x]);
if(tr[mid].rs)disr=f(tr[tr[mid].rs],tr[x]);
if(disl<ans&&disr<ans)
{
if(disl<disr)
{
query(l,mid-1,x);
if(disr<ans)query(mid+1,r,x);
}else
{
query(mid+1,r,x);
if(disl<ans)query(l,mid-1,x);
}
}else
{
if(disl<ans)query(l,mid-1,x);
if(disr<ans)query(mid+1,r,x);
}
}
int main()
{
scanf("%d",&T);
rep(tc,1,T)
{
scanf("%d",&n);
rep(i,1,n)scanf("%lf%lf",&tr[i].x,&tr[i].y);
build(1,n);
ans=inf;
rep(i,1,n)
{
scanf("%lf%lf",&tr[n+1].x,&tr[n+1].y);
query(1,n,n+1);
}
printf("%.3f\n",sqrt(ans));
}
return 0;
}
最后
这题还是花了几天时间(多少还是有点拖拉x)才做到算是让自己满意的地步,期间和几个oi/acm群的群友讨论了分治做法,然后再把它们卡掉…以及EI还提到了Voronoi diagram的\(O(n\log n)\)的做法,感谢这些靠谱的小伙伴萌。
参考资料: