KDTree求平面最长最短点对
更新日志
前言
不会细致讲解KDT内容,如有需要,出门左转KDTree。
这篇文章以最常用的二维点集为例(包括模板),其他维度同理。
思路(优化)
我们考虑2-D Tree
,维护整个点集。
最朴素的做法是,每次都将当前节点与标准点更新答案,并进入其两个子树计算。不难发现,就是暴搜,没有意义。
引入一个重点:估价函数。
这个函数用来估计一个区间与标准点可能的最大值。注意到,这个最大值不一定真实存在,所以称之为估价。
更具体的,我们使用整个区间的最大最小坐标与标准点比较。
这个函数的具体作用在 query
部分,在进入一个子树前,先预估这个子树最理想的价值,如果最理想的情况都不比当前 res
更优,就无需进入这个区间了。
另一个优化是优先进入左右子树的顺序。我们看标准点在当前维度上位于中点的左侧或右侧,并以此优先进入可能更优的区间。
比如说获取最短点对,我们就优先进入同侧子树,反之同理。
模板
这个模板求曼哈顿距离下的最短点对,其他距离同理。
这个模板内附了动态插入节点的功能(二进制分组)。
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define rep(i,x,y) for(int i=x;i<=y;i++)
const int inf=0x3f3f3f3f;
const int N=6e5+5,K=2,T=20;
int n,m;
struct node{
int v[K];
int mn[K],mx[K];
int lson,rson;
}ns[N];
int ak;
bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}
inline int dis(int a,int b){
return abs(ns[a].v[0]-ns[b].v[0])+abs(ns[a].v[1]-ns[b].v[1]);
}
inline int rec(int a,int b){
int res=0;
rep(k,0,K-1)res+=max(ns[b].v[k]-ns[a].mx[k],0)+max(ns[a].mn[k]-ns[b].v[k],0);
return res;
}
struct kdtree{
int acnt;
int arr[N];
int rt[T];
void update(int x){
rep(k,0,K-1){
ns[x].mn[k]=ns[x].mx[k]=ns[x].v[k];
if(ns[x].lson)chmin(ns[x].mn[k],ns[ns[x].lson].mn[k]),chmax(ns[x].mx[k],ns[ns[x].lson].mx[k]);
if(ns[x].rson)chmin(ns[x].mn[k],ns[ns[x].rson].mn[k]),chmax(ns[x].mx[k],ns[ns[x].rson].mx[k]);
}
}
int build(int l,int r,int k=0){
int m=l+r>>1;
ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
if(l<m)ns[arr[m]].lson=build(l,m-1,k^1);
if(m<r)ns[arr[m]].rson=build(m+1,r,k^1);
update(arr[m]);
return arr[m];
}
void append(int &x){
arr[++acnt]=x;
if(ns[x].lson)append(ns[x].lson);
if(ns[x].rson)append(ns[x].rson);
x=0;
}
void insert(int x){
acnt=0;
arr[++acnt]=x;
rep(t,0,T-1){
if(rt[t])append(rt[t]);
else{
rt[t]=build(1,acnt);
break;
}
}
}
void insert(int x,int y){
n++;
ns[n].v[0]=x;
ns[n].v[1]=y;
insert(n);
}
int query(int x,int st,int k){
int res=inf;
if(x!=st)chmin(res,dis(x,st));
if(ns[st].v[k]<ns[x].v[k]){
if(ns[x].lson&&rec(ns[x].lson,st)<res)chmin(res,query(ns[x].lson,st,k^1));
if(ns[x].rson&&rec(ns[x].rson,st)<res)chmin(res,query(ns[x].rson,st,k^1));
}else{
if(ns[x].rson&&rec(ns[x].rson,st)<res)chmin(res,query(ns[x].rson,st,k^1));
if(ns[x].lson&&rec(ns[x].lson,st)<res)chmin(res,query(ns[x].lson,st,k^1));
}
return res;
}
int query(int x,int y){
ns[0].v[0]=x;ns[0].v[1]=y;
int res=inf;
rep(t,0,T-1)
if(rt[t])res=min(res,query(rt[t],0,0));
return res;
}
}kdt;
附常用估价函数
统一规定 \(st\) 为标准点,\(now\) 为当前节点。
曼哈顿
最短
大致来说,在某一维度上,如果当前点在区间外,就取其到最近的边界的距离,否则就是 \(0\)。
\[\max(now.x-st.x_{\max},0)+\max(st.x_{\min}-now.x,0)+\max(now.y-st.y_{\max},0)+\max(st.y_{\min}-now.y,0)
\]
最长
容易理解,就是到边界的最大距离。
\[\max(|now.x-st.x_{\max}|,|now.x-st.x_{\min})|+\max(|now.y-st.y_{\max}|,|now.y-st.y_{\min}|)
\]
欧几里得
前注:考虑精度问题,均不开方。
最短
同理,区间外则到边界,区间内则为 \(0\)。
\[\max(now.x-st.x_{\max},st.x_{\min}-now.x,0)^2+\max(now.y-st.y_{\max},st.y_{\min}-now.y,0)^2
\]
最长
同理,到边界距离最大。
\[\max(now.x-st.x_{\max},now.x-st.x_{\min})^2+\max(now.y-st.y_{\max},now.y-st.y_{\min})^2
\]
例题
代码
前注:非题解,不做详细讲解
#include<bits/stdc++.h>
using namespace std;
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define rep(i,x,y) for(int i=x;i<=y;i++)
const int inf=0x3f3f3f3f;
const int N=6e5+5,K=2,T=20;
int n,m;
struct node{
int v[K];
int mn[K],mx[K];
int lson,rson;
}ns[N];
int ak;
bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}
inline int dis(int a,int b){
return abs(ns[a].v[0]-ns[b].v[0])+abs(ns[a].v[1]-ns[b].v[1]);
}
inline int rec(int a,int b){
int res=0;
rep(k,0,K-1)res+=max(ns[b].v[k]-ns[a].mx[k],0)+max(ns[a].mn[k]-ns[b].v[k],0);
return res;
}
struct kdtree{
int acnt;
int arr[N];
int rt[T];
void update(int x){
rep(k,0,K-1){
ns[x].mn[k]=ns[x].mx[k]=ns[x].v[k];
if(ns[x].lson)chmin(ns[x].mn[k],ns[ns[x].lson].mn[k]),chmax(ns[x].mx[k],ns[ns[x].lson].mx[k]);
if(ns[x].rson)chmin(ns[x].mn[k],ns[ns[x].rson].mn[k]),chmax(ns[x].mx[k],ns[ns[x].rson].mx[k]);
}
}
int build(int l,int r,int k=0){
int m=l+r>>1;
ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
if(l<m)ns[arr[m]].lson=build(l,m-1,k^1);
if(m<r)ns[arr[m]].rson=build(m+1,r,k^1);
update(arr[m]);
return arr[m];
}
void append(int &x){
arr[++acnt]=x;
if(ns[x].lson)append(ns[x].lson);
if(ns[x].rson)append(ns[x].rson);
x=0;
}
void insert(int x){
acnt=0;
arr[++acnt]=x;
rep(t,0,T-1){
if(rt[t])append(rt[t]);
else{
rt[t]=build(1,acnt);
break;
}
}
}
void insert(int x,int y){
n++;
ns[n].v[0]=x;
ns[n].v[1]=y;
insert(n);
}
int query(int x,int st,int k){
int res=inf;
if(x!=st)chmin(res,dis(x,st));
if(ns[st].v[k]<ns[x].v[k]){
if(ns[x].lson&&rec(ns[x].lson,st)<res)chmin(res,query(ns[x].lson,st,k^1));
if(ns[x].rson&&rec(ns[x].rson,st)<res)chmin(res,query(ns[x].rson,st,k^1));
}else{
if(ns[x].rson&&rec(ns[x].rson,st)<res)chmin(res,query(ns[x].rson,st,k^1));
if(ns[x].lson&&rec(ns[x].lson,st)<res)chmin(res,query(ns[x].lson,st,k^1));
}
return res;
}
int query(int x,int y){
ns[0].v[0]=x;ns[0].v[1]=y;
int res=inf;
rep(t,0,T-1)
if(rt[t])res=min(res,query(rt[t],0,0));
return res;
}
}kdt;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
cin>>n>>m;
rep(i,1,n){
cin>>ns[i].v[0]>>ns[i].v[1];
kdt.insert(i);
}
int t,x,y;
while(m--){
cin>>t>>x>>y;
if(t==1)kdt.insert(x,y);
if(t==2)cout<<kdt.query(x,y)<<"\n";
}
return 0;
}