NOIP模拟测试10 Problem B:模板:树上启发式合并

最开始以为是线段树合并,结果邓鸽鸽说线段树合并必死,布星。

热情的邓鸽鸽向我传授了我没有学习过的船新算法 树上启发式合并

学完之后发现就是很普通的启发式合并用到了树上而已

先说做法,给每个节点按时间轴开个动态开点线段树,节点保存种类和球数。易于发现每个点的球与它的子树有关。先把修改操作保存起来,开始处理整棵树。对于每个节点,我们需要判断这个球有没有出现在桶,因为这影响了我们对于种类的统计。于是再开一个数组存每个球最早出现的时间,只有最早出现的球才是有贡献的。

这道题的瓶颈在于答案的统计。可以发现我们算的每个点的答案都是子树答案的叠加,暴力的做法是遍历所有子树,这个复杂度是\(O(N^2)\)的,是30分暴力。

这太暴力力,,,于是就有了树上启发式合并,它没有某谷日报说的那么玄乎,其实和普通的启发式合并差不多。直接说优化方法:我们暴力算出轻儿子的答案,期间都清空辅助数组,只保留答案。重儿子我们就不用清空辅助数组了,算出答案,这样子节点的答案都算出来了,就差父亲的了。我们把轻儿子的答案直接合并到重儿子的答案,就得到了父节点的答案。

放在这道题,每次轻儿子计算后清空记小球出现时间的那个数组,然后算重儿子,之后将轻儿子上的线段树都启发式合并到重儿子的线段树上,父节点直接继承这棵线段树,不断递归这个过程。这个做法复杂度是\(O(Nlog^2N)\)级别的,就可以过了。

真难写,调了好久。

这道题学到的东西蛮多的。首先我之前一直没看动态开点线段树,学前置芝士的时候直接学了。然后我经常口胡启发式合并,但从来没写过,码力不足,这次也真的写了一次。第三,好久没写过这么复杂的题了,属实锻炼码力。

说句闲话,我之前一直WA40,一看Query没return,默认return了堆栈的top。败RP啊。。。。

#include <bits/stdc++.h>

const int N = 1e5 + 233;
int n, k[N], m, q, ecnt, head[N], color[N], disc[N];
int early[N], stk[N * 200], tp, ans[N], root[N];
struct Edge {
	int to, nxt;
} e[N << 1];
int ls[N * 200], rs[N * 200], siz[N * 200], val[N * 200];
std::vector<int> op[N];

inline void add_edge(int f, int to) {
	e[++ecnt] = {to, head[f]}, head[f] = ecnt;
}

int fa[N], sz[N], son[N];

void dfs1(int x, int f) {
	fa[x] = f, sz[x] = 1 + (int) op[x].size();
	for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to) {
		if (y != f) {
			dfs1(y, x);
			sz[x] += sz[y];
			if (sz[y] > sz[son[x]]) son[x] = y;
		}
	}
}

int tot, rbin[N * 200], rbin_top;

int new_node() {
	if (tot + 1 < N * 200) return ++tot;
	else return rbin[rbin_top--];
}

void del_node(int x) {
	rbin[++rbin_top] = x;
}

void pushup(int p) {
	siz[p] = siz[ls[p]] + siz[rs[p]];
	val[p] = val[ls[p]] + val[rs[p]];
}

void change(int &p, int L, int R, int x, int y) {
	if (!p) p = new_node();
	if (L == R) {
		siz[p] = 1;
		if (y != -1) val[p] = y;
		return;
	}
	int mid = (L + R) >> 1;
	if (x <= mid) change(ls[p], L, mid, x, y);
	else change(rs[p], mid + 1, R, x, y);
	pushup(p);
}

void merge(int p, int &root, int L, int R) {
	if (!p) return;
	if (L == R) {
		if (early[color[L]] > L) {
			change(root, 1, m, early[color[L]], 0);
			change(root, 1, m, L, 1);
			early[color[L]] = L;
		} else if (early[color[L]] == 0) {
			change(root, 1, m, L, 1);
			early[color[L]] = L;
			stk[++tp] = color[L];
		} else {
			change(root, 1, m, L, -1);
		}
		del_node(p);
	}
	int mid = (L + R) >> 1;
	merge(ls[p], root, L, mid);
	merge(rs[p], root, mid + 1, R);
}

int query(int p, int L, int R, int bucket) {
	if (!bucket || !p) return 0;
	if (siz[p] <= bucket) return val[p];
	int mid = (L + R) >> 1, ret = 0;
	if (siz[ls[p]] < bucket) {
		ret += val[ls[p]];
		ret += query(rs[p], mid + 1, R, bucket - siz[ls[p]]);
	} else {
		ret += query(ls[p], L, mid, bucket);
	}
	return ret;
}

void clear() {
	while (tp > 0) early[stk[tp--]] = 0;
}

void solve(int x) {
	for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to)
		if (y != son[x] && y != fa[x]) solve(y), clear(); //先解决轻儿子的答案
	if (son[x]) solve(son[x]);
	root[x] = root[son[x]];
	for (unsigned int i = 0; i < op[x].size(); i++) {
		int co = color[op[x][i]];
		if (early[co] == 0) {
			change(root[x], 1, m, op[x][i], 1);
			early[co] = op[x][i];
			stk[++tp] = co;	
		} else if (early[co] > op[x][i]) {
			change(root[x], 1, m, early[co], 0);
			change(root[x], 1, m, op[x][i], 1);
			early[co] = op[x][i];
		} else {
			change(root[x], 1, m, op[x][i], -1);
		}
	}
	for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to) {
		if (y != fa[x] && y != son[x]) {
			merge(root[y], root[x], 1, m);
		}
	}
	ans[x] = query(root[x], 1, m, k[x]);
}

signed main() {
	scanf("%d", &n);
	for (int i = 1, x, y; i <= n - 1; i++)
		scanf("%d%d", &x, &y), add_edge(x, y), add_edge(y, x);
	for (int i = 1; i <= n; i++)
		scanf("%d", k + i);
	scanf("%d", &m);
	for (int i = 1, x, y; i <= m; i++)
		scanf("%d%d", &x, &y), disc[i] = color[i] = y, op[x].push_back(i);
	std::sort(disc + 1, disc + 1 + m);
	int QwQ = std::unique(disc + 1, disc + 1 + m) - (disc + 1);
	for (int i = 1; i <= m; i++)
		color[i] = std::lower_bound(disc + 1, disc + 1 + QwQ, color[i]) - disc;
	scanf("%d", &q);
	dfs1(1, 0);
	solve(1);
	for (int i = 1, x; i <= q; i++)
		scanf("%d", &x), printf("%d\n", ans[x]);
	return 0;
}
posted @ 2019-07-30 11:20  Gekoo  阅读(217)  评论(0编辑  收藏  举报