【树上莫队】【SP10707】 COT2 - Count on a tree II

Description

给定一棵 \(n\) 个点的树,每个节点有一个权值,\(m\) 次询问,每次查询两点间路径上有多少不同的权值

Input

第一行是 \(n\)\(m\)

第二行是 \(n\) 个整数描述点权

下面 \(n - 1\) 行描述这棵树

最后 \(m\) 行每行两个整数代表一次查询

Output

对每个查询输出一行一个整数代表答案

Hint

\(1~\leq~n~\leq~40000,~1~\leq~m~\leq~10^5\)。权值范围为 \([1,10^9]\)

Solution

这类数颜色题目,如果放到序列上那么妥妥的莫队,我们考虑放到树上该如何继续暴力莫队

树上问题转化成序列问题一般都需要用到遍历序,这里借助于括号遍历序,即在每进入节点 \(u\) 的时候记录 \(u\) 的序号,在完全退出 \(u\) 及其子树的时候再记录一遍 \(u\) 的序号,这样一共会被记录两次。

\(st[u]\) 为进入 \(u\) 时记录的位置对应下标,\(ed[u]\) 为退出时记录的位置对应下标。

以下不妨设 \(st[u]~<~st[v]\),即先访问 \(u\) 再访问 \(v\)

括号遍历序有一个良好的性质,对于两个点 \(u\)\(v\)\(u\)\(v\) 之间的链上的点在\(st[u]\)\(st[v]\) 之间出现且仅出现一次。同时这个条件也是必要条件,于是我们可以利用这个性质只统计链上点的信息。

考虑当 \(u\)\(v\) 的祖先时,直接处理括号遍历序上两点的信息就可以了。

考虑当他们的 LCA 不为 \(u\) 的情况,我们发现 LCA 在括号遍历序中是没有出现过的,于是我们需要特判 LCA。另外注意到 \(v\) 不在 \(u\) 的子树中,所以 \(u\) 子树中的信息毫无左右,可以直接从统计 \(ed[u]\)\(st[v]\) 的信息,而不是 \(st[u]\)\(st[v]\)

注意 ST 求 LCA 用的是欧拉遍历序而不是括号序列,所以要记两个遍历序。

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#ifdef ONLINE_JUDGE
#define freopen(a, b, c)
#endif

typedef long long int ll;

namespace IPT {
	const int L = 1000000;
	char buf[L], *front=buf, *end=buf;
	char GetChar() {
		if (front == end) {
			end = buf + fread(front = buf, 1, L, stdin);
			if (front == end) return -1;
		}
		return *(front++);
	}
}

template <typename T>
inline void qr(T &x) {
	char ch = IPT::GetChar(), lst = ' ';
	while ((ch > '9') || (ch < '0')) lst = ch, ch=IPT::GetChar();
	while ((ch >= '0') && (ch <= '9')) x = (x << 1) + (x << 3) + (ch ^ 48), ch = IPT::GetChar();
	if (lst == '-') x = -x;
}

namespace OPT {
	char buf[120];
}

template <typename T>
inline void qw(T x, const char aft, const bool pt) {
	if (x < 0) {x = -x, putchar('-');}
	int top=0;
	do {OPT::buf[++top] = static_cast<char>(x % 10 + '0');} while (x /= 10);
	while (top) putchar(OPT::buf[top--]);
	if (pt) putchar(aft);
}

const int maxn = 80010;
const int maxm = 100010;

struct Edge {
	int to;
	Edge *nxt;
};
Edge *hd[maxn];
inline void cont(int from, int to) {
	Edge *e = new Edge;
	e->to = to; e->nxt = hd[from]; hd[from] = e;
}

int n, m, vistime, sn, cnt, etime;
int dfn[maxn], st[maxn], ed[maxn], ST[18][maxn], deepth[maxn], belong[maxn], MU[maxn], bk[maxn], elur[maxn];
bool vis[maxn];

struct Ask {
	int l, r, id, ans, lca;

	inline bool operator<(const Ask &_others) const {
		if (belong[this->l] != belong[_others.l]) return this->l < _others.l;
		else if (belong[this->l] & 1) return this->r < _others.r;
		else return this->r > _others.r;
	}
};
Ask ask[maxm];

void dfs(int);
void MAKE_ST();
int cmp(int, int);
void update(int);
void add(int);
void dlt(int);
int Get_Lca(int, int);
void init_hash();
bool qwq(const Ask&, const Ask&);

int main() {
	freopen("1.in", "r", stdin);
	qr(n); qr(m);
	for (int i = 1; i <= n; ++i) qr(MU[i]);
	init_hash();
	for (int i = 1, a, b; i < n; ++i) {
		a = b = 0; qr(a); qr(b); cont(a, b); cont(b, a);
	}
	dfs(1); memset(vis, 0, sizeof vis);
	MAKE_ST();
	int sn = sqrt(vistime);
	for (int i = 1; i <= vistime; ++i) belong[i] = i / sn;
	for (int i = 1; i <= m; ++i) {
		int &l = ask[i].l, &r = ask[i].r;
		qr(l); qr(r); ask[i].id = i;
		if (st[l] > st[r]) std::swap(l, r);
		if (Get_Lca(l, r) == l) l = st[l], r = st[r];
		else {
			ask[i].lca = Get_Lca(l, r); l = ed[l]; r = st[r]; 
		}
	}
	std::sort(ask + 1, ask + 1 + m); bk[0] = 1;
	for (int i = 1, prel = ask[1].l, prer = prel - 1; i <= m; ++i) {
		int l = ask[i].l, r = ask[i].r;
		while (prel < l) update(dfn[prel++]);
		while (prel > l) update(dfn[--prel]);
		while (prer < r) update(dfn[++prer]);
		while (prer > r) update(dfn[prer--]);
		ask[i].ans = cnt; 
		if (!bk[MU[ask[i].lca]]) ++ask[i].ans;
	}
	std::sort(ask + 1, ask + 1 + m, qwq);
	for (int i = 1; i <= m; ++i) qw(ask[i].ans, '\n', true);
	return 0;
}

void dfs(int x) {
	dfn[st[x] = ++vistime] = x;
	elur[++etime] = x;
	vis[x] = true;
	for (Edge *e = hd[x]; e; e = e->nxt) if (!vis[e->to]) {
		deepth[e->to] = deepth[x] + 1;
		dfs(e->to); elur[++etime] = x;
	}
	dfn[ed[x] = ++vistime] = x;
}

void MAKE_ST() {
	for (int i = 1; i <= vistime; ++i) ST[0][i] = elur[i];
	for (int i = 1; i < 18; ++i) {
		int di = i - 1;
		for (int l = 1; l <= vistime; ++l) {
			int r = l + (1 << i) - 1; if (r > vistime) break;
			ST[i][l] = cmp(ST[di][l], ST[di][l + (1 << di)]);
		}
	}
}

inline int cmp(int x, int y) {
	if (deepth[x] < deepth[y]) return x;
	return y;
}

inline void add(int x) {
	if ((bk[x]++) == 0) ++cnt;
}

inline void dlt(int x) {
	if ((--bk[x]) == 0) --cnt;
}

void update(int x) {
	if ((vis[x] ^= 1)) add(MU[x]);
	else dlt(MU[x]);
}

int Get_Lca(int x, int y) {
	int l = st[x], r = st[y];
	int len = log2(r - l + 1);
	return cmp(ST[len][l], ST[len][r - (1 << len) + 1]);
}

void init_hash() {
	static int temp[maxn];
	memcpy(temp, MU, (n + 1) * (sizeof(int)));
	std::sort(temp + 1, temp + 1 + n);
	int *ed = std::unique(temp + 1, temp + 1 + n);
	for (int i = 1; i <= n; ++i) MU[i] = std::lower_bound(temp + 1, ed, MU[i]) - temp;
}

inline bool qwq(const Ask &_a, const Ask &_b) {
	return _a.id < _b.id;
}
posted @ 2019-02-23 16:11  一扶苏一  阅读(323)  评论(0编辑  收藏  举报