solution - 简单题(K-Dimension Tree)

solution - 简单题(K-Dimension Tree)

咕了这么久,终于可以来讲讲KDT了。

说句实话,KDT的算法是非常简单的,但是很少有人能很快的写对,总是会出现一些奇奇怪怪的BUG,我自己也写了一个下午。主要是写代码时注意结构的对称性,以及算法的模块性,一个function干一件事就行。

说了这么多,就开始讲算法吧。

#1 算法描述

考虑一种二叉树的结构,其中每一个节点有两个功能:

  • 存储这个节点\(\textbf{p}=[x_0, x_1, \cdots,x_{k-1}]\)的值
  • 存储一个 \(k\) 维区域的记录,并且通过这一个节点的左右儿子(\(\textbf p_l, \textbf p_r\))将其在某一个维度将其分为两半,形式化地,就是对于任意的节点\(\textbf a \in \{\textbf p_l \text{及其后代}\}, \textbf b \in \{\textbf p_r \text{及其后代}\}\),存在某一个维度向量\(\textbf T = [x_0,x_1,\cdots,x_{k-1}], \text{其中} x_t = 1, x_{p\not=t} = 0\),有$\textbf T \cdot \textbf a \leq \textbf T \cdot \textbf p < \textbf T \cdot \textbf b $

考虑到算法的简便性,我们人为规定 \(t\) 为这个节点的深度模 \(k\)

那么就十分简单了,可以很清楚的实现这个算法的插入,查询,删除。

但是考虑到算法的单次复杂度还是 \(\textrm O(n)\),需要优化。

我们可以使用替罪羊树的思路优化,即一但某个节点的某个左右儿子的重量大于这个节点的重量的 \(\alpha\) 倍,那么就直接重构这个树。其中 \(\alpha\) 基本在 \(0.75\) 附近最好。

下面是代码实现。

#2 代码实现细节

#2.1 节点定义

先是定义节点。

template <int D>
struct KDT {
	KDT<D>* ls = nullptr, * rs = nullptr;
	int mx[D], mn[D], pos[D];
	int val; // 这个结点的值
	int sum; // 这个节点及其子节点的值的和
	int tot; // 这个节点及其子节点的数量
	const bool operator < (KDT t) const {
		return t.val < val;
	} // 为了pair所必须的
};

其中 mx, mn 为这个节点及其子树的在所有维度的极大值与极小值。

pos 为这个节点的维度值。

其他的见注释。

你可以注意到这里使用了指针来定义。

#2.2 插入

这里写一下这个程序的伪代码:

$ \textbf {function}\ Insert (\text{这个节点的指针}nx, \text{插入的维度} \textbf d, \text{节点的值}val,\text{节点的深度depth(模过D)}) $

\(\ \ \ \ \textbf{if not exist } nx\ \textbf {then new }nx\)

$\ \ \ \ \textbf{else if } d_{val} \leq nx.pos_{val} $

$\ \ \ \ \ \ \ \ \textbf { then }nx \leftarrow insert(nx\rightarrow ls, d, val, depth + 1) $

\(\ \ \ \ \ \ \ \ \textbf { else } nx \leftarrow insert(nx \rightarrow rs, d, val, depth + 1);\)

$\ \ \ \ \textbf {if not } \text{balance} \textbf { then } rebuild(nx) $

\(\ \ \ \ update(nx)\)

$\ \ \ \ \textbf {return } nx $

\(\textbf {end function}\)

template <int D>
KDT<D>* insert(KDT<D>* nx, int d[D], int val, int depth) {
	if (depth >= D) depth -= D;
	if (nx == nullptr) {
		nx = new KDT<D>;
		for (int i = 0; i < D; i++) {
			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
		}
		nx->val = nx->sum = val;
		nx->tot = 1;
		return nx;
	}
	else {
		int flag = 1;
		for (int i = 0; i < D; i++) {
			if (d[i] != nx->pos[i])
				flag = 0;
		}
		if (flag) {
			nx->val += val;
			update(nx);
			return nx;
		}
		if (d[depth] <= nx->pos[depth]) {
			nx->ls = insert(nx->ls, d, val, depth + 1);
		} else {
			nx->rs = insert(nx->rs, d, val, depth + 1);
		}
		update(nx);
		int mx = 0;
		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
		if (mx > nx->tot * alpha) {
			pair<int, KDT<D> >* arr = 
                new pair<int, KDT<D> >[nx->tot + 10];
			pl = 0;
			pia(nx, arr);
			nx = rebuild(1, pl + 1, depth, arr);
			delete[]arr;
		}
		return nx;
	}
}

这里请注意以下rebuild块的写法。

我的写法是新建一块内存存储被删除的节点 pair<int, KDT<D> >* arr = new pair<int, KDT<D> >[nx->tot + 10];

然后将其节点删除,并放至arr中。

代码如下:

template <int D>
void pia(KDT<D>* ptr, pair<int, KDT<D> >* arr) {
	if (ptr != nullptr) {
		arr[++pl].second = *ptr;
		pia(ptr->ls, arr);
		pia(ptr->rs, arr);
		delete ptr;
	}
}

最后是rebuild的一块。

代码如下:

template <int D>
KDT<D>* rebuild(int L, int R, int dep, pair<int, KDT<D> >* arr) {
	if (L >= R) return nullptr;
	if (dep > D) dep -= D;
	for (int i = L; i < R; i++) {
		arr[i].first = arr[i].second.pos[dep];
	}
	int mid = (L + R) >> 1;
	nth_element(arr + L, arr + mid, arr + R);
	KDT<D>* ret = new KDT<D>;
	*ret = arr[mid].second;
	ret->ls = rebuild(L, mid, dep + 1, arr);
	ret->rs = rebuild(mid + 1, R, dep + 1, arr);
	update(ret);
	return ret;
}

这里为了偷懒才用了系统的nth_element,否则可以不用pair数组。

#2.3查询

这一块比较简单,不予赘述。其中allin函数表示这个节点及其子节点全在所给范围之中。allout 相反。 in表示这个单独的点是否在区域中。

代码如下:

template <int D>
int get_ans(KDT<D>* nx, int mx[D], int mn[D]) {
	if (nx == nullptr) return 0;
	if (allout(nx, mx, mn)) {
		return 0;
	}
	if (allin(nx, mx, mn)) return nx->sum;
	int ret = 0;
	if (in(nx, mx, mn)) {
		ret = nx->val;
	}
	ret += get_ans(nx->ls, mx, mn);
	ret += get_ans(nx->rs, mx, mn);
	return ret;
}

#3 代码呈现

#include<cstdio>
#include<algorithm>

const double alpha = 0.75;
const int maxn = 210000;

using namespace std;

template <int d>
struct kdt {
	kdt<d>* ls = nullptr, * rs = nullptr;
	int mx[d], mn[d], pos[d];
	int val;
	int sum;
	int tot;
	const bool operator < (kdt t) const {
		return t.val < val;
	}
};

//pair <int, kdt<t> > arr[maxn];
int pl = 0;
template <int d>
void update(kdt<d>* ret) {
	ret->tot = 1;
	ret->sum = ret->val;
	for (int i = 0; i < d; i++) {
		ret->mx[i] = ret->mn[i] = ret->pos[i];
	}
	if (ret->ls != nullptr) {
		for (int i = 0; i < d; i++)
			ret->mx[i] = max(ret->mx[i], ret->ls->mx[i]),
			ret->mn[i] = min(ret->mn[i], ret->ls->mn[i]);
		ret->sum += ret->ls->sum;
		ret->tot += ret->ls->tot;
	}
	if (ret->rs != nullptr) {
		for (int i = 0; i < d; i++)
			ret->mx[i] = max(ret->mx[i], ret->rs->mx[i]),
			ret->mn[i] = min(ret->mn[i], ret->rs->mn[i]);
		ret->sum += ret->rs->sum;
		ret->tot += ret->rs->tot;
	}
}

template <int d>
void pia(kdt<d>* ptr, pair<int, kdt<d> >* arr) {
	if (ptr != nullptr) {
		arr[++pl].second = *ptr;
		pia(ptr->ls, arr);
		pia(ptr->rs, arr);
		delete ptr;
	}
}

template <int d>
kdt<d>* rebuild(int l, int r, int dep, pair<int, kdt<d> >* arr) {
	if (l >= r) return nullptr;
	if (dep > d) dep -= d;
	for (int i = l; i < r; i++) {
		arr[i].first = arr[i].second.pos[dep];
	}
	int mid = (l + r) >> 1;
	nth_element(arr + l, arr + mid, arr + r);
	kdt<d>* ret = new kdt<d>;
	*ret = arr[mid].second;
	ret->ls = rebuild(l, mid, dep + 1, arr);
	ret->rs = rebuild(mid + 1, r, dep + 1, arr);
	update(ret);
	return ret;
}

template <int d>
kdt<d>* insert(kdt<d>* nx, int d[d], int val, int depth) {
	if (depth >= d) depth -= d;
	if (nx == nullptr) {
		nx = new kdt<d>;
		for (int i = 0; i < d; i++) {
			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
		}
		nx->val = nx->sum = val;
		nx->tot = 1;
		return nx;
	}
	else {
		int flag = 1;
		for (int i = 0; i < d; i++) {
			if (d[i] != nx->pos[i])
				flag = 0;
		}
		if (flag) {
			nx->val += val;
			update(nx);
			return nx;
		}
		if (d[depth] < nx->pos[depth]) {
			nx->ls = insert(nx->ls, d, val, depth + 1);
		} else {
			nx->rs = insert(nx->rs, d, val, depth + 1);
		}
		update(nx);
		int mx = 0;
		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
		if (mx > nx->tot * alpha) {
			pair<int, kdt<d> >* arr = new pair<int, kdt<d> >[nx->tot + 10];
			pl = 0;
			pia(nx, arr);
			nx = rebuild(1, pl + 1, depth, arr);
			delete[]arr;
		}
		return nx;
	}
}

template <int d>
int allin(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->mx[i] > mx[i]) {
			return 0;
		}
		if (nx->mn[i] < mn[i]) {
			return 0;
		}
	}
	return 1;
}

template <int d>
int allout(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->mn[i] > mx[i]) {
			return 1;
		}
		if (nx->mx[i] < mn[i]) {
			return 1;
		}
	}
	return 0;
}

template <int d>
int in(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->pos[i] > mx[i]) {
			return 0;
		}
		if (nx->pos[i] < mn[i]) {
			return 0;
		}
	}
	return 1;
}

template <int d>
int get_ans(kdt<d>* nx, int mx[d], int mn[d]) {
	if (nx == nullptr) return 0;
	if (allout(nx, mx, mn)) {
		return 0;
	}
	if (allin(nx, mx, mn)) return nx->sum;
	int ret = 0;
	if (in(nx, mx, mn)) {
		ret = nx->val;
	}
	ret += get_ans(nx->ls, mx, mn);
	ret += get_ans(nx->rs, mx, mn);
	return ret;
}

int main() {
	kdt<2>* root = nullptr;
	int n;scanf("%d", &n);
	int lst_ans = 0;
	while (1) {
		int opt;
		scanf("%d", &opt);
		if (opt == 3) break;
		if (opt == 1) {
			int d[2] = { 0,0}, val;
			scanf("%d%d%d", d, d + 1, &val);
			d[0] ^= lst_ans, d[1] ^= lst_ans, val ^= lst_ans;
			root = insert(root, d, val, 0);
		}
		if (opt == 2) {
			int mx[2] = { 0,0 }, mn[2] = { 0,0 };
			scanf("%d%d%d%d", mn, mn + 1, mx, mx + 1);
			mx[0] ^= lst_ans, mx[1] ^= lst_ans;
			mn[0] ^= lst_ans, mn[1] ^= lst_ans;
			lst_ans = get_ans(root, mx, mn);
			printf("%d\n", lst_ans);
		}
	}
}
posted @ 2020-10-25 17:13  dgklr  阅读(156)  评论(0编辑  收藏  举报