Loading

K-D Tree 模版

K-D Tree 模版

下面这个模版(洛谷P1429为例)可以实现k维空间查询给定点的m邻近点,存储于que这个大根堆中。

#include<bits/stdc++.h>
#define ll long long
#define sq(x) (x)*(x)
#define N (250000)
using namespace std;
int idx, k, n, m, q;//idx为某个维度的编号 k为维度个数
struct Node
{
    double x[5];
    bool operator < (const Node &u) const
    {
        return x[idx] < u.x[idx];//以全局变量idx对应的维度为排序标准
    }
} P[N];
typedef pair<double, Node> PDN;//用于优先队列的排序
priority_queue<PDN> que;
struct KD_Tree
{
    int sz[N << 2]; 
    Node p[N << 2];
    void build (int i, int l, int r, int dep) {//dep初始为0 l初始为0 r初始为n - 1
        if (l > r) return;
        int mid = (l + r) >> 1;
        idx = dep % k;
        sz[i] = r - l;
        sz[i << 1] = sz[i << 1 | 1] = -1;
        nth_element(P + l, P + mid, P + r + 1);//O(n)选取中位数
        p[i] = P[mid];
        build(i << 1, l, mid - 1, dep + 1);
        build(i << 1 | 1, mid + 1, r, dep + 1);
    }

    void query (int i, int m, int dep, Node a) {//查询点为a dep为深度 m为最近邻m个
        if (sz[i] == -1) return;
        PDN tmp = PDN(0, p[i]);
        for(int j = 0; j < k; j++)
            tmp.first += 1.0 * sq(1.0 * tmp.second.x[j] - 1.0 * a.x[j]);//这里是对应的欧式距离 如果要用闵可夫斯基距离的话需要手动修改
        tmp.first = sqrt(tmp.first);
        int lc = i << 1,rc = i << 1 | 1, dim = dep % k, flag = 0;
        if (a.x[dim] >= p[i].x[dim]) swap(lc, rc);
        //如果sz[lc]为-1,~sz[lc]为0,就不进行搜索了
        if (~sz[lc]) query(lc, m, dep + 1, a);
        if (que.size() < m) que.push(tmp), flag = 1;
        else {
            if (tmp.first < que.top().first) que.pop(), que.push(tmp);
            //它的兄弟所在的区域里面可能会有比它距离查询点更近的点,所以要对它的兄弟也进行搜索
            if (sqrt(1.0 * sq(1.0 * a.x[dim] - 1.0 * p[i].x[dim])) < que.top().first) flag = 1;
        }
        if (~sz[rc] && flag) query(rc, m, dep + 1, a);
    }
} KDT;
 
int main() {
    k = 2;
    cin >> n;
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < k; j++) {
            cin >> P[i].x[j];
        }
    }
    KDT.build(1, 0, n - 1, 0);
    double ans = 1e18;
    m = 2;//最近的点 但会把自己算上(自己到自己的距离为0)因此m取2
    for(int i = 0; i < n; i++) {
        KDT.query(1, m, 0, P[i]);
        if(que.top().first < ans) {
            ans = que.top().first;
        }
        while(que.size()) que.pop();
    }
    printf("%.4lf", ans);
    return 0;
}

还有一道比较好的题是洛谷P4357,查询平面第K远点对。一种做法是首先找到全局的最远点对(x, y),然后从这个点对的一个点y出发找到除(x, y)的最远点对(y, z),如此进行K次(有点类似树的直径的两次BFS),去重使用unordered_map。但上面的模版无法实现查询最远点对。题解的思路是启发式搜索降低复杂度,对于kdt的每个节点维护其超矩形边界(实际上就是二维,因为算法竞赛中基本都是二维)。等有时间可以尝试将上述模版改成可以查询平面k近/k远点对~下面先放上zero4438大佬的模版:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<unordered_map>
#include<queue>
#define ll long long
using namespace std;
const int maxn=1e5+10;
int n,k;
int comp;
struct point
{
	int num;
	ll d[2];
	friend bool operator <(point a,point b){return a.d[comp]<b.d[comp];}
	friend ll dis(point a,point b){return (a.d[0]-b.d[0])*(a.d[0]-b.d[0])+(a.d[1]-b.d[1])*(a.d[1]-b.d[1]);}//统一把根号去掉 
};
unordered_map<int,bool>ban[maxn];//记录每个元素已经和谁配对 
struct par
{
	point x,y;ll dis;
	friend bool operator <(par a,par b){return a.dis<b.dis;}
};
priority_queue<par>q;
struct kdtree
{
	struct node
	{
		point p;ll minv[2],maxv[2];int l,r;
	};
	node t[maxn];
	point sta[maxn];
	int rt,cnt;
	par res;
	void update(int u)
	{
		int l=t[u].l,r=t[u].r;
		for(int i=0;i<=1;i++)
		{
			t[u].minv[i]=t[u].maxv[i]=t[u].p.d[i];
			if(l)t[u].minv[i]=min(t[u].minv[i],t[l].minv[i]),t[u].maxv[i]=max(t[u].maxv[i],t[l].maxv[i]);
			if(r)t[u].minv[i]=min(t[u].minv[i],t[r].minv[i]),t[u].maxv[i]=max(t[u].maxv[i],t[r].maxv[i]);
		}
	}
	int build(int l,int r,int f)
	{
		int mid=(l+r)>>1;
		comp=f;
		nth_element(sta+l,sta+mid,sta+r+1);
		int u=++cnt;
		t[u].p=sta[mid];
		if(l<mid)t[u].l=build(l,mid-1,f^1);
		if(r>mid)t[u].r=build(mid+1,r,f^1);
		update(u);return u;
	}
	ll getmax(int u,point x)
	{
		ll ret=0;
		for(int i=0;i<=1;i++)ret+=max(x.d[i]-t[u].minv[i],t[u].maxv[i]-x.d[i])*max(x.d[i]-t[u].minv[i],t[u].maxv[i]-x.d[i]);
		return ret;
	}
	void query(int u,point x)//常规 K-D tree 计算最远点 
	{
		if(ban[x.num].find(t[u].p.num)==ban[x.num].end()&&dis(t[u].p,x)>res.dis)//只有当该点没有与询问点配过对时才更新答案 
		{
			res.y=t[u].p;res.dis=dis(t[u].p,x);
		}
		int l=t[u].l,r=t[u].r;
		ll dl=0,dr=0;
		if(l)dl=getmax(l,x);if(r)dr=getmax(r,x);
		if(dl>dr)
		{
			if(dl>res.dis)query(l,x);if(dr>res.dis)query(r,x);
		}
		else 
		{
			if(dr>res.dis)query(r,x);if(dl>res.dis)query(l,x);
		}
	}
}tr;
par ans;
int main()
{
	scanf("%d%d",&n,&k);k*=2;//令 k=k*2 
	for(int i=1;i<=n;i++)scanf("%lld%lld",&tr.sta[i].d[0],&tr.sta[i].d[1]),tr.sta[i].num=i;
	tr.rt=tr.build(1,n,0);
	for(int i=1;i<=n;i++)
	{
		tr.res.x=tr.res.y=tr.sta[i];tr.res.dis=0;
		tr.query(tr.rt,tr.sta[i]);//计算最远点 
		q.push(tr.res); //放入大根堆 
		ban[tr.res.x.num][tr.res.y.num]=1;//记录已出现点对 
	}
	while(k--)
	{
		par now=q.top();q.pop();//取出当前最远点 
		ans=now;
		tr.res.x=tr.res.y=now.x;tr.res.dis=0;
		tr.query(tr.rt,now.x);//放入次远点 
		q.push(tr.res);ban[tr.res.x.num][tr.res.y.num]=1;
	}
	printf("%lld",ans.dis);
	return 0;
}
posted @ 2022-02-14 16:58  脂环  阅读(66)  评论(0编辑  收藏  举报