[学习笔记]K-D Tree

推荐:

k-d tree算法

 

对于D维的点若干,多次查询距离某个点第K大的点是什么。

处理这一类问题的一个数据结构,叫K-D Tree

 

基本思想是对点进行区域分块处理。

图示:

K-D Tree是一个二叉树。

每个点维护的信息是,

split :分裂坐标轴

ls、rs:左右儿子

node:该节点存储的真实点

 

建树:

递归建树。类似平衡树

选择当前区域的点的各个维度的方差最大的维度(传说如果方差大,数据分散,复杂度或者精度有所保证??),把这个维度当做split

这个节点的真实点就是c[mid]

然后,把这个维度[s]小于c[mid][s]的放在左边,大于的放在右边。

(实现时,用一个nth_element,再重载小于号,可以O(n)实现把中间的放在mid位置上,并且这个维度[s]小于c[mid][s]的放在左边,大于的放在右边。)

然后递归建树即可。

 

x,y是split

 

这样,整个K-D Tree就把一些点分成了若干个块。

我们一块一块处理会比较容易剪枝。

K=1

查询:最近的点(即K=1)

本质是爆搜+剪枝。。。

设查询距离点st的最近的点to

设距离为now

法一:

不断通过当前split维度和st这个维度的大小比较,

我们先走st所属的块,

回溯回来之后,

由于可能在另一半有更近的点。

如果分界线到st的该维度距离小于now

那么再走另外一个块搜索。

 

法二:

上面那个剪枝比较粗糙。

我们发现,一个块的所有点,其实可以用一个矩形框住。

那么,如果st到这个矩形可能的最近点距离小于now的话,再搜下去。

具体来说,我们每个节点维护这个节点代表的块内,最大最小的x,y坐标。(其实就是矩形四个顶点)

这个最短距离:

x差为:dx=max(st.x-x.mxx,0)+max(x.min-st.x,0)

画个图理解下,如果st的x在mix,mxx之间的话,那么x差就认为是0

如果在mix左边,那么就是x-mix,

如果在mxx右边,那么就是mxx-x

y同理。

dis=sqrt(dx*dx+dy*dy)

发现,当st在所属的块内时,dis一定是0

然后就可以剪枝了。

对于两个儿子,选择估价距离较小的那个先搜,回溯时,如果另一个距离还比now小的话,再搜另一个。

理论上,应该比法一多减一些枝。

 

总之,复杂度不明。

传说最差O(k*n^(1-1/k))每次(k是维度)

 

例题:

平面最近点对(加强版)

这个题用分治是最好的。

我们用KD树来试试。

 

枚举所有的点,找到与它最近的点距离,然后所有距离取min即可

直接做就好了。

 

注意:

1.如果写法二,那么对于0号节点的哨兵必须mi=inf,mx=-inf。否则剪枝就挂了。

2.建树的时候,build返回节点编号,不能返回计数器tot。。。。

 

这个题法一更快??

可能数据水,然后法二常数大吧。。。

 

法一:

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=200000+5;
const double inf=2333333333.00;
int n;
struct po{
    double x,y;
    int id;
    po(){}
    po(double xx,double yy){
        x=xx;y=yy;
    }
}a[N],c[N],st,to;
bool cmp1(po a,po b){
    return a.x<b.x;
}
bool cmp2(po a,po b){
    return a.y<b.y;
}
double ans;
double now;
double dis(po a,po b){
    return sqrt((a.x-b.x)*(a.x-b.x)+(a.y-b.y)*(a.y-b.y));
}
struct tr{
    int sp;
    po O;
    int ls,rs;
}t[2*N];
int tot;
int rt;
int build(int l,int r){
    if(l>r){
        return 0;
    }
    if(l==r){
        ++tot;
        t[tot].O=c[l];
        t[tot].ls=t[tot].rs=0;
        t[tot].sp=1;
        return tot;
    }
    int mid=(l+r)>>1;
    double ax=0,ay=0;
    for(reg i=l;i<=r;++i) ax+=c[i].x,ay+=c[i].y;
    ax/=(r-l+1);ay/=(r-l+1);
    double fx=0,fy=0;
    for(reg i=l;i<=r;++i) fx+=(c[i].x-ax)*(c[i].x-ax),fy+=(c[i].y-ay)*(c[i].y-ay);
    fx/=(r-l+1);fy/=(r-l+1);
    int ret=++tot;
    if(fx>fy){//choose x;
        nth_element(c+l,c+mid,c+r+1,cmp1);
        t[ret].sp=1;
    }else{
        nth_element(c+l,c+mid,c+r+1,cmp2);
        t[ret].sp=2;
    }
    t[ret].O=c[mid];
    t[ret].ls=build(l,mid-1);
    t[ret].rs=build(mid+1,r);
    return ret;
}
void dfs(int x){
    if(!x) return;
    if(st.id!=t[x].O.id&&dis(st,t[x].O)<now){
        now=dis(st,t[x].O);
    }
    if(t[x].sp==1){
        double d=fabs(t[x].O.x-st.x);
        if(st.x<=t[x].O.x){
            dfs(t[x].ls);
            if(d<now) dfs(t[x].rs);
        }
        else{
            dfs(t[x].rs);
            if(d<now) dfs(t[x].ls);
        }
    }
    else{
        double d=fabs(t[x].O.y-st.y);
        if(st.y<=t[x].O.y){
            dfs(t[x].ls);
            if(d<now) dfs(t[x].rs);
        }
        else{
            dfs(t[x].rs);
            if(d<now) dfs(t[x].ls);
        }
    }
}
int main(){
    scanf("%d",&n);
    for(reg i=1;i<=n;++i){
        scanf("%lf%lf",&a[i].x,&a[i].y);
        a[i].id=i;
        c[i]=a[i];
    }
    rt=build(1,n);
    ans=inf;
    for(reg i=1;i<=n;++i){
        st=a[i];
        now=inf;
        to=po(inf,inf);
        dfs(1);
        ans=min(ans,now);
    }
    printf("%.4lf",ans);
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/26 8:43:17
*/
法一

法二:

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=200000+5;
const double inf=2333333333.00;
int n;
struct po{
    double x,y;
    int id;
    po(){}
    po(double xx,double yy){
        x=xx;y=yy;
    }
}a[N],c[N],st,to;
bool cmp1(po a,po b){
    return a.x<b.x;
}
bool cmp2(po a,po b){
    return a.y<b.y;
}
double ans;
double now;
double dis(po a,po b){
    return sqrt((a.x-b.x)*(a.x-b.x)+(a.y-b.y)*(a.y-b.y));
}
struct tr{
    double mxx,mix,mxy,miy;
    int sp;
    po O;
    int ls,rs;
}t[2*N];
int tot;
int rt;
int build(int l,int r){
    if(l>r){
        return 0;
    }
    if(l==r){
        ++tot;
        t[tot].mxx=t[tot].mix=c[l].x;
        t[tot].mxy=t[tot].miy=c[l].y;
        t[tot].O=c[l];
        t[tot].ls=t[tot].rs=0;
        t[tot].sp=1;
        return tot;
    }
    int mid=(l+r)>>1;
    double ax=0,ay=0;
    for(reg i=l;i<=r;++i) ax+=c[i].x,ay+=c[i].y;
    ax/=(r-l+1);ay/=(r-l+1);
    double fx=0,fy=0;
    for(reg i=l;i<=r;++i) fx+=(c[i].x-ax)*(c[i].x-ax),fy+=(c[i].y-ay)*(c[i].y-ay);
    fx/=(r-l+1);fy/=(r-l+1);
    int ret=++tot;
    if(fx>fy){//choose x;
        nth_element(c+l,c+mid,c+r+1,cmp1);
        t[ret].sp=1;
    }else{
        nth_element(c+l,c+mid,c+r+1,cmp2);
        t[ret].sp=2;
    }
    t[ret].O=c[mid];
    t[ret].ls=build(l,mid-1);
    t[ret].rs=build(mid+1,r);
    t[ret].mxx=max(t[t[ret].rs].mxx,t[t[ret].ls].mxx);
    t[ret].mix=min(t[t[ret].rs].mix,t[t[ret].ls].mix);
    t[ret].mxy=max(t[t[ret].rs].mxy,t[t[ret].ls].mxy);
    t[ret].miy=min(t[t[ret].rs].miy,t[t[ret].ls].miy);
    //cout<<" ret "<<ret<<" "<<l<<" "<<r<<endl;
    return ret;
}
void dfs(int x){
    if(st.id!=t[x].O.id&&dis(st,t[x].O)<now){
        now=dis(st,t[x].O);
        to=t[x].O;
    }
    if(t[x].ls&&t[x].rs){
        double lx=max(st.x-t[t[x].ls].mxx,0.0)+max(t[t[x].ls].mix-st.x,0.0);
        double ly=max(st.y-t[t[x].ls].mxy,0.0)+max(t[t[x].ls].miy-st.y,0.0);
        double len1=sqrt(lx*lx+ly*ly);
        double rx=max(st.x-t[t[x].rs].mxx,0.0)+max(t[t[x].rs].mix-st.x,0.0);
        double ry=max(st.y-t[t[x].rs].mxy,0.0)+max(t[t[x].rs].miy-st.y,0.0);
        double len2=sqrt(rx*rx+ry*ry);
        if(len1<=len2&&len1<now){
            dfs(t[x].ls);
            if(len2<now) 
                dfs(t[x].rs);
        }
        else if(len2<=len1&&len2<now){
            dfs(t[x].rs);
            if(len1<now) 
                dfs(t[x].ls);
        }
    }
    else if(t[x].ls){
        double lx=max(st.x-t[t[x].ls].mxx,0.0)+max(t[t[x].ls].mix-st.x,0.0);
        double ly=max(st.y-t[t[x].ls].mxy,0.0)+max(t[t[x].ls].miy-st.y,0.0);
        double len1=sqrt(lx*lx+ly*ly);
        if(len1<now) 
        dfs(t[x].ls);
    }
    else if(t[x].rs){
        double rx=max(st.x-t[t[x].rs].mxx,0.0)+max(t[t[x].rs].mix-st.x,0.0);
        double ry=max(st.y-t[t[x].rs].mxy,0.0)+max(t[t[x].rs].miy-st.y,0.0);
        double len2=sqrt(rx*rx+ry*ry);
        if(len2<now) 
        dfs(t[x].rs);
    }
    else return;
}
int main(){
    scanf("%d",&n);
    t[0].mix=inf;t[0].mxx=-inf;
    t[0].miy=inf;t[0].mxy=-inf;
    for(reg i=1;i<=n;++i){
        scanf("%lf%lf",&a[i].x,&a[i].y);
        a[i].id=i;
        c[i]=a[i];
    }
    rt=build(1,n);
    ans=inf;
    for(reg i=1;i<=n;++i){
    //    cout<<" ii "<<i<<" : "<<a[i].x<<" "<<a[i].y<<" ------------------ "<<endl;
        st=a[i];
        now=inf;
        to=po(inf,inf);
        dfs(1);
    //    cout<<" after "<<now<<endl;
        ans=min(ans,now);
    }
    printf("%.4lf",ans);
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/26 8:43:17
*/
法二

 

 

 对了,KD-Tree其实也可以不记录左右儿子,以及代表实际点

因为,每次我们选择的是mid位置的点,之后这个点的位置也不会再动了。

而左右儿子区间也是定值。

所以,query时记录(l,r)即可,访问实际点的话,直接取c[mid]就好。

 


 

upda:2019.2.14

蒟蒻博主学了半天忘了学一般情况下的K了。。。

K任意

K比较小、前K这种:

求最远点为例,

用一个大小为K的小根堆,

对于当前区间的代表点,

如果q.size()<K那么放进去

否则如果dis>q.top()那么q.pop()然后放进去

左右儿子:先找距离大的,如果q.size()<K或者估价函数估计会有比堆顶更远的,就去搜。

正确性:

每次保留当前最大的K个,只要证明最大的K个都能取到就是正确的

由于估价函数存在,如果某个最大K个不在堆中,一定比堆顶优,会去找过去的。

复杂度:玄学。剪枝怎么优秀怎么剪就是了。

 

例题:

 

bzoj 2626: JZPFAR

 

CQOI2016 K远点对

(把每个点都来像上面那样找一遍,堆的大小是2*K)

// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define reg register int
#define il inline
#define fi first
#define se second
#define mk(a,b) make_pair(a,b)
#define numb (ch^'0')
#define pb push_back
#define solid const auto &
#define enter cout<<endl
#define pii pair<int,int>
using namespace std;
typedef long long ll;
template<class T>il void rd(T &x){
    char ch;x=0;bool fl=false;while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);(fl==true)&&(x=-x);}
template<class T>il void output(T x){if(x/10)output(x/10);putchar(x%10+'0');}
template<class T>il void ot(T x){if(x<0) putchar('-'),x=-x;output(x);putchar(' ');}
template<class T>il void prt(T a[],int st,int nd){for(reg i=st;i<=nd;++i) ot(a[i]);putchar('\n');}
namespace Modulo{
const int mod=998244353;
int ad(int x,int y){return (x+y)>=mod?x+y-mod:x+y;}
void inc(int &x,int y){x=ad(x,y);}
int mul(int x,int y){return (ll)x*y%mod;}
void inc2(int &x,int y){x=mul(x,y);}
int qm(int x,int y=mod-2){int ret=1;while(y){if(y&1) ret=mul(x,ret);x=mul(x,x);y>>=1;}return ret;}
}
//using namespace Modulo;
namespace Miracle{
const int N=1e5+5;
const int inf=1<<30;
int n,K;
struct po{
    int x,y;
    po(){}
    po(int xx,int yy){x=xx;y=yy;}
}p[N],st;
priority_queue<ll,vector<ll>,greater<ll> >q;
struct node{
    po a;
    int mx[2],mi[2];
    int ls,rs;
}t[N];
ll dis(po a,po b){
    return (ll)((ll)a.x-b.x)*((ll)a.x-b.x)+(ll)(a.y-b.y)*(ll)(a.y-b.y);
}
bool cmp1(po a,po b){
    return a.x<b.x;
}
bool cmp2(po a,po b){
    return a.y<b.y;
}
#define mid ((l+r)>>1)
#define ls t[x].ls
#define rs t[x].rs
void pushup(int x){
    t[x].mi[0]=t[x].mx[0]=t[x].a.x;
    t[x].mi[1]=t[x].mx[1]=t[x].a.y;
    for(reg i=0;i<2;++i){
        if(ls) t[x].mx[i]=max(t[ls].mx[i],t[x].mx[i]);
        if(rs) t[x].mx[i]=max(t[rs].mx[i],t[x].mx[i]);
        if(ls) t[x].mi[i]=min(t[ls].mi[i],t[x].mi[i]);
        if(rs) t[x].mi[i]=min(t[rs].mi[i],t[x].mi[i]);
    }
}
int build(int l,int r,int sp){
    if(l>r) return 0;
    int x=mid;
    if(l==r){
        t[x].a=p[l];
        t[x].mi[0]=t[x].mx[0]=t[x].a.x;
        t[x].mi[1]=t[x].mx[1]=t[x].a.y;
        return x;
    }
    if(sp==1) nth_element(p+l+1,p+mid+1,p+r+1,cmp1);
    else nth_element(p+l+1,p+mid+1,p+r+1,cmp2);
    t[x].a=p[mid];
    ls=build(l,mid-1,sp^1);
    rs=build(mid+1,r,sp^1);
    pushup(x);
    return x;
}
void upda(const po &a,const po &b){
    ll d=dis(a,b);
    if(q.top()<d) q.pop(),q.push(d);
}
bool ok(ll d){
    return q.top()<d;
}
ll get(int x){
    ll dx=max(abs(st.x-t[x].mi[0]),abs(st.x-t[x].mx[0]));
    ll dy=max(abs(st.y-t[x].mi[1]),abs(st.y-t[x].mx[1]));
    return dx*dx+dy*dy;
}
void query(int x,int l,int r){
    if(!x) return;
    if(l==r){
        upda(st,t[x].a);return;
    }
    ll dl=-inf,dr=-inf;
    upda(st,t[x].a);
    if(ls)dl=get(ls);
    if(rs)dr=get(rs);
    if(ok(max(dl,dr))){
        if(dl>dr) {
            query(ls,l,mid-1);
            if(ok(dr)) query(rs,mid+1,r);
        }else{
            query(rs,mid+1,r);
            if(ok(dl)) query(ls,l,mid-1);
        }
    }
}
int main(){
    rd(n);rd(K);
    K=K*2;
    for(reg i=1;i<=K;++i) q.push(0);
    for(reg i=1;i<=n;++i) rd(p[i].x),rd(p[i].y);
    int rt=build(1,n,1);
    for(reg i=1;i<=n;++i){
        st=p[i];query(rt,1,n);
    }
    ot(q.top());
    return 0;
}

}
signed main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
*/
View Code

注意,这个是平衡树的pushup!所以要考虑自己的这个点!

 

如果K达到O(N)级别

二分距离看SZ不知是不是更好一些?

(反正K很小的话,用堆的logn很小,自然快)

 

升级操作

K-Dtree也可以查询区间(矩阵区域)信息,

平衡树、线段树怎么做,它就怎么做

K-Dtree范围计数

1.

K-D 树 –范围计数:例题 1
• 给出一棵树,每次询问或修改这样一个区域:
• 在 𝑝 子树内,距离 𝑝 不超过 𝑑 的所有点

• 𝑛, 𝑚 ≤ 2 × 10^5

不修改,直接主席树

修改就要K-D tree了

一个点:坐标:(dfn[x],dep[x])前者处理子树,后者处理距离

K-Dtree上打标记+区域查询

2.

K-D 树 –范围计数:例题 2
• 给出一棵树
• 每个点都有两个数 𝑎𝑖 和 𝑏𝑖
• 每个点有个变量 𝑥𝑖,每天 𝑥𝑖 ← min 𝑏𝑖, 𝑥𝑖 + 𝑎𝑖
• 一个事件有它的发生日期,发生的时候,会选择一个例题 1 中的
形式的区域,求区域中的 𝑥𝑖 的和,然后将区域中的 𝑥𝑖 全部清零
• 对每个事件输出 𝑥𝑖 的和
• 𝑛, 𝑚 ≤ 2 × 10^5

由于 𝑥𝑖 的每次清空,每次询问的答案只和“上次修改时间”有关
• 每次是子矩形设置“上次修改时间”,是打标记
• 在打标记的时候,DFS 并回收 K-D 树子树内的所有标记
  如果一个节点有标记
  那么它子树内一定没标记(因为被回收了)
  并且这个点子树内的所有点的“上次修改时间”相同

(标记也要pushdown,查询的时候,边pushdown边找,找到完全包含于询问的区间,如果子树没有标记,直接做,否则往有标记的子树里去回收

然后集体打标记。复杂度多一个2倍的常数而已。


• 而对每个树上的点,𝑥𝑖 对修改时间差是分段一次函数
• 对每个 K-D 树节点维护子树内分段一次函数的和,对应 𝑥𝑖 的和的分段
• 如果标记回收到它,就在分段函数上二分求出 𝑥𝑖 的和
• 通过离线询问可以去掉二分(???存疑)

K-Dtree优化DP

2018.7.23 模拟赛总结

K-D 树 –更多应用

 

posted @ 2018-11-26 11:57  *Miracle*  阅读(389)  评论(0编辑  收藏  举报