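// bzoj 3053: The Closest M Points [KD-tree]
// Multi-dimensional KD-tree template.
// The estimate (lower bound) for the left/right child treats mn..mx as an interval and assumes every value inside that interval exists; the k dimensions take turns as the splitting axis.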
#include<algorithm>
#include<climits>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
using namespace std;

// Capacity of the point array (points are stored from index 1).
const int N = 50005;

int n;        // number of points in the current data set (read in main)
int k;        // number of dimensions in the current data set (read in main)
int m;        // number of queries left in the current data set
int rt;       // index of the KD-tree root, as returned by build()
int w;        // axis that qwe::operator< currently compares on
int ans[15];  // node indices popped from q for the query being printed (slots 1..s)

// Max-heap of (squared distance, node index) candidates for the current query;
// the top element is the farthest candidate currently kept.
priority_queue<pair<int, int> > q;
// A point with room for 5 integer coordinates.
struct qwe
{
    int a[5];   // coordinate storage

    // Mutable access to the coordinate on `axis`; no bounds checking.
    int& operator [] (int axis)
    {
        return this->a[axis];
    }

    // Orders two points by their coordinate on the axis held in the global `w`.
    // The caller must set `w` before any algorithm that relies on this ordering.
    bool operator < (const qwe &rhs) const
    {
        return rhs.a[w] > a[w];
    }
};

qwe a[N];   // input points; main fills indices 1..n
qwe b;      // the query point currently being answered
// One KD-tree node. A node's index in `t` is the index of the point it holds
// (build() uses the median position `mid` as the node id), and 0 means "no child".
struct KD
{
    int ls, rs;   // indices of the left / right child in `t`, 0 if absent
    qwe d;        // the point stored at this node
    qwe mn, mx;   // per-axis minimum / maximum over the whole subtree (bounding box)
};

// Node ids coincide with point indices, so one slot per point is enough.
// The former N<<2 sizing only quadrupled the memory footprint and the cost of
// clearing the array.
KD t[N];
// Reads the next decimal integer from stdin using getchar().
//
// Every non-digit character before the number is skipped; if any of the skipped
// characters is '-', the result is negated. The first character after the digits
// is consumed as well.
//
// Returns the parsed value, or 0 if end of input is reached before any digit.
// (Previously the character was stored in a `char` and EOF was never tested, so
// the skip loop spun forever once the input was exhausted.)
int read()
{
    int value = 0;
    int sign = 1;
    int ch = getchar();   // keep as int so EOF stays distinguishable from real bytes

    // Skip everything up to the first digit, but stop at end of input.
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            sign = -1;
        ch = getchar();
    }

    // Accumulate the digits; EOF is not a digit, so this loop also stops there.
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Lowers x to y when y is strictly smaller; otherwise leaves x untouched.
void minn(int &x,int y)
{
    if (y < x)
        x = y;
}
// Raises x to y when y is strictly larger; otherwise leaves x untouched.
void maxx(int &x,int y)
{
    if (y > x)
        x = y;
}
void ud(int ro)
{
if(t[ro].ls)
for(int i=0;i<k;i++)
minn(t[ro].mn[i],t[t[ro].ls].mn[i]),maxx(t[ro].mx[i],t[t[ro].ls].mx[i]);
if(t[ro].rs)
for(int i=0;i<k;i++)
minn(t[ro].mn[i],t[t[ro].rs].mn[i]),maxx(t[ro].mx[i],t[t[ro].rs].mx[i]);
}
int build(int l,int r,int f)
{
if(l>r)
return 0;
w=f;
int mid=(l+r)>>1;
nth_element(a+l,a+mid,a+r+1);
t[mid].mn=t[mid].mx=t[mid].d=a[mid];
t[mid].ls=build(l,mid-1,(f+1)%k);
t[mid].rs=build(mid+1,r,(f+1)%k);
ud(mid);
return mid;
}
// Squared Euclidean distance between points a and b over the first k axes.
// The points are taken by value because qwe::operator[] is non-const.
int dis(qwe a,qwe b)
{
    int total = 0;
    for (int axis = 0; axis < k; ++axis)
    {
        const int delta = a[axis] - b[axis];
        total += delta * delta;
    }
    return total;
}
int wk(int ro)
{
int r=0;
for(int i=0;i<k;i++)
{
if(b[i]<t[ro].mn[i])
r+=(t[ro].mn[i]-b[i])*(t[ro].mn[i]-b[i]);
if(b[i]>t[ro].mx[i])
r+=(t[ro].mx[i]-b[i])*(t[ro].mx[i]-b[i]);
}
return r;
}
void ques(int ro,int f)
{
if(!ro)
return;
int dm=dis(t[ro].d,b),dl=t[ro].ls?wk(t[ro].ls):1e9,dr=t[ro].rs?wk(t[ro].rs):1e9;//cerr<<"OK"<<dm<<endl;
if(q.top().first>dm)
q.pop(),q.push(make_pair(dm,ro));
if(dl<dr)
{
if(dl<q.top().first)
ques(t[ro].ls,(f+1)%k);
if(dr<q.top().first)
ques(t[ro].rs,(f+1)%k);
}
else
{
if(dr<q.top().first)
ques(t[ro].rs,(f+1)%k);
if(dl<q.top().first)
ques(t[ro].ls,(f+1)%k);
}
}
// Driver. For every data set on stdin: reads n points of k coordinates, builds
// the KD-tree, then answers m queries. Each query is a point followed by a count
// s, and the s stored points closest to it are printed nearest first.
int main()
{
    // Require both values; a short or failed match ends the loop instead of
    // spinning on unparsable input.
    while (scanf("%d%d", &n, &k) == 2)
    {
        // No clearing of `t` is needed: build() rewrites every field of nodes
        // 1..n, and node 0 is never written.
        for (int i = 1; i <= n; i++)
            for (int j = 0; j < k; j++)
                a[i][j] = read();
        rt = build(1, n, 0);

        m = read();
        while (m--)
        {
            for (int i = 0; i < k; i++)
                b[i] = read();
            int s = read();

            // Pre-fill the heap with s placeholders that any real point beats.
            // INT_MAX is used because squared distances are only bounded by
            // the int range; the former 1e9 placeholder could be smaller than
            // a real distance and then stayed in the result.
            for (int i = 1; i <= s; i++)
                q.push(make_pair(INT_MAX, 0));

            ques(rt, 0);

            // The heap pops farthest first, so ans[1] is the farthest result.
            for (int i = 1; i <= s; i++)
                ans[i] = q.top().second, q.pop();

            printf("the closest %d points are:\n",s);
            for (int i = s; i >= 1; i--)   // walk backwards: nearest first
            {
                for (int j = 0; j < k; j++)
                    printf("%d ",t[ans[i]].d[j]);
                puts("");
            }
        }
    }
    return 0;
}