Loading

【JZOJ3360】【NOI2013模拟】苹果树

题目大意

给你一棵\(n\)个点的树,每个点有一种颜色;现在有\(m\)个询问,每次询问你\(x\)\(y\)的路径上,若将\(a\)颜色视作\(b\)颜色,不同的颜色有几种。

\(n\leq 50000,m\leq 100000\)

分析

如果是把问题放到序列上:询问区间\([l,r]\)不同的颜色有几种。这个问题有两个已知的解法:

看这题的数据范围显然是让你莫队了。(雾

树上莫队的第一步,是把树上问题转换为序列问题。我们求出原树的欧拉序,可以发现这个序列有这样的性质:

将一个点在欧拉序中首次出现和第二次出现的位置分别记作\(fir_u\)\(las_u\),对于一条路径\((x,y)\)(假定\(fir_x<fir_y\))。
\(lca(x,y)=x\),那么这条路径对应欧拉序中的区间\([fir_x,fir_y]\)。但是区间中出现两次的点要去掉,因为它们不属于这条路径。
\(lca(x,y)\neq x\),那么这条路径对应欧拉序中的区间\([las_x,fir_y]\)。同样的要去掉出现两次的点,并且这个区间没有包括上\(lca\),要将\(lca\)再单独统计。

这样,树上问题就变成了序列问题。

为了不计算出现两次的点,我们开个标记数组,一个点每次出现,都把标记数组对应位置异或\(1\),那么一个点在标记数组中的值为\(1\)时才能被计算,当一个点对应的值变为\(0\)时又把它的贡献删去,这样问题便迎刃而解。再注意计算\(lca\)的答案即可。关于将\(a\)颜色视作\(b\)颜色的,只需判断区间中是否同时有\(a\)颜色和\(b\)颜色,有的话答案减\(1\),注意\(a=b\)要特判,不然要炸!

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 200007;

int n, m, col[N], ans[N], ord[N];
int tot, dfn, st[N], to[N << 1], nx[N << 1], fir[N], las[N], anc[N][17], dep[N];
void add(int u, int v) { to[++tot] = v, nx[tot] = st[u], st[u] = tot; }
void dfs(int u)
{
	fir[u] = ++dfn, ord[dfn] = u;
	for (int i = st[u]; i; i = nx[i]) if (!fir[to[i]]) anc[to[i]][0] = u, dep[to[i]] = dep[u] + 1, dfs(to[i]);
	las[u] = ++dfn, ord[dfn] = u;
}
int getlca(int u, int v)
{
	if (dep[u] < dep[v]) swap(u, v);
	for (int i = 16; i >= 0; i--) if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if (u == v) return u;
	for (int i = 16; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
int block, ret, be[N], tag[N], buc[N];
struct note { int l, r, id, a, b, lca; } q[N];
int cmp(note a, note b) { return be[a.l] == be[b.l] ? ((be[a.l] & 1) ? a.r < b.r : a.r > b.r) : a.l < b.l; }
void ins(int c, int v)
{
	if (v == 1) { if (!buc[c]) ret++; buc[c]++; }
	else { buc[c]--; if (!buc[c]) ret--; }
}

int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
	for (int i = 1, u, v; i <= n; i++)
	{
		scanf("%d%d", &u, &v);
		if (u && v) add(u, v), add(v, u);
	}
	dep[1] = 1, dfs(1);
	for (int j = 1; j <= 16; j++) for (int i = 1; i <= n; i++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
	block = sqrt(2 * n);
	for (int i = 1; i <= 2 * n; i++) be[i] = i / block + 1;
	for (int i = 1, x, y, a, b, lca; i <= m; i++)
	{
		scanf("%d%d%d%d", &x, &y, &a, &b);
		if (fir[x] > fir[y]) swap(x, y);
		lca = getlca(x, y);
		if (lca == x) q[i] = (note){fir[x], fir[y], i, a, b, 0};
		else q[i] = (note){las[x], fir[y], i, a, b, lca};
	}
	sort(q + 1, q + m + 1, cmp);
	for (int i = 1, l = 1, r = 0; i <= m; i++)
	{
		while (l < q[i].l) tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]), ++l;
		while (l > q[i].l) --l, tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]);
		while (r < q[i].r) ++r, tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]);
		while (r > q[i].r) tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]), --r;
		if (q[i].lca) ins(col[q[i].lca], 1);
		ans[q[i].id] = ret;
		if (q[i].a != q[i].b && buc[q[i].a] && buc[q[i].b]) ans[q[i].id]--;
		if (q[i].lca) ins(col[q[i].lca], 0);
	}
	for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
	return 0;
}
posted @ 2019-07-12 21:29  gz-gary  阅读(180)  评论(0编辑  收藏  举报