【树上莫队】【SP10707】 COT2 - Count on a tree II
Description
给定一棵 \(n\) 个点的树,每个节点有一个权值,\(m\) 次询问,每次查询两点间路径上有多少不同的权值
Input
第一行是 \(n\) 和 \(m\)
第二行是 \(n\) 个整数描述点权
下面 \(n - 1\) 行描述这棵树
最后 \(m\) 行每行两个整数代表一次查询
Output
对每个查询输出一行一个整数代表答案
Hint
\(1~\leq~n~\leq~40000,~1~\leq~m~\leq~10^5\)。权值范围为 \([1,10^9]\)
Solution
这类数颜色题目,如果放到序列上那么妥妥的莫队,我们考虑放到树上该如何继续暴力莫队
树上问题转化成序列问题一般都需要用到遍历序,这里借助于括号遍历序,即在每进入节点 \(u\) 的时候记录 \(u\) 的序号,在完全退出 \(u\) 及其子树的时候再记录一遍 \(u\) 的序号,这样一共会被记录两次。
设 \(st[u]\) 为进入 \(u\) 时记录的位置对应下标,\(ed[u]\) 为退出时记录的位置对应下标。
以下不妨设 \(st[u]~<~st[v]\),即先访问 \(u\) 再访问 \(v\)。
括号遍历序有一个良好的性质,对于两个点 \(u\) 和 \(v\), \(u\) 和 \(v\) 之间的链上的点在\(st[u]\) 和 \(st[v]\) 之间出现且仅出现一次。同时这个条件也是必要条件,于是我们可以利用这个性质只统计链上点的信息。
考虑当 \(u\) 为 \(v\) 的祖先时,直接处理括号遍历序上两点的信息就可以了。
考虑当他们的 LCA 不为 \(u\) 的情况,我们发现 LCA 在括号遍历序中是没有出现过的,于是我们需要特判 LCA。另外注意到 \(v\) 不在 \(u\) 的子树中,所以 \(u\) 子树中的信息毫无左右,可以直接从统计 \(ed[u]\) 到 \(st[v]\) 的信息,而不是 \(st[u]\) 到 \(st[v]\)。
注意 ST 求 LCA 用的是欧拉遍历序而不是括号序列,所以要记两个遍历序。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#ifdef ONLINE_JUDGE
#define freopen(a, b, c)
#endif
typedef long long int ll;
namespace IPT {
const int L = 1000000;
char buf[L], *front=buf, *end=buf;
char GetChar() {
if (front == end) {
end = buf + fread(front = buf, 1, L, stdin);
if (front == end) return -1;
}
return *(front++);
}
}
template <typename T>
inline void qr(T &x) {
char ch = IPT::GetChar(), lst = ' ';
while ((ch > '9') || (ch < '0')) lst = ch, ch=IPT::GetChar();
while ((ch >= '0') && (ch <= '9')) x = (x << 1) + (x << 3) + (ch ^ 48), ch = IPT::GetChar();
if (lst == '-') x = -x;
}
namespace OPT {
char buf[120];
}
template <typename T>
inline void qw(T x, const char aft, const bool pt) {
if (x < 0) {x = -x, putchar('-');}
int top=0;
do {OPT::buf[++top] = static_cast<char>(x % 10 + '0');} while (x /= 10);
while (top) putchar(OPT::buf[top--]);
if (pt) putchar(aft);
}
const int maxn = 80010;
const int maxm = 100010;
struct Edge {
int to;
Edge *nxt;
};
Edge *hd[maxn];
inline void cont(int from, int to) {
Edge *e = new Edge;
e->to = to; e->nxt = hd[from]; hd[from] = e;
}
int n, m, vistime, sn, cnt, etime;
int dfn[maxn], st[maxn], ed[maxn], ST[18][maxn], deepth[maxn], belong[maxn], MU[maxn], bk[maxn], elur[maxn];
bool vis[maxn];
struct Ask {
int l, r, id, ans, lca;
inline bool operator<(const Ask &_others) const {
if (belong[this->l] != belong[_others.l]) return this->l < _others.l;
else if (belong[this->l] & 1) return this->r < _others.r;
else return this->r > _others.r;
}
};
Ask ask[maxm];
void dfs(int);
void MAKE_ST();
int cmp(int, int);
void update(int);
void add(int);
void dlt(int);
int Get_Lca(int, int);
void init_hash();
bool qwq(const Ask&, const Ask&);
int main() {
freopen("1.in", "r", stdin);
qr(n); qr(m);
for (int i = 1; i <= n; ++i) qr(MU[i]);
init_hash();
for (int i = 1, a, b; i < n; ++i) {
a = b = 0; qr(a); qr(b); cont(a, b); cont(b, a);
}
dfs(1); memset(vis, 0, sizeof vis);
MAKE_ST();
int sn = sqrt(vistime);
for (int i = 1; i <= vistime; ++i) belong[i] = i / sn;
for (int i = 1; i <= m; ++i) {
int &l = ask[i].l, &r = ask[i].r;
qr(l); qr(r); ask[i].id = i;
if (st[l] > st[r]) std::swap(l, r);
if (Get_Lca(l, r) == l) l = st[l], r = st[r];
else {
ask[i].lca = Get_Lca(l, r); l = ed[l]; r = st[r];
}
}
std::sort(ask + 1, ask + 1 + m); bk[0] = 1;
for (int i = 1, prel = ask[1].l, prer = prel - 1; i <= m; ++i) {
int l = ask[i].l, r = ask[i].r;
while (prel < l) update(dfn[prel++]);
while (prel > l) update(dfn[--prel]);
while (prer < r) update(dfn[++prer]);
while (prer > r) update(dfn[prer--]);
ask[i].ans = cnt;
if (!bk[MU[ask[i].lca]]) ++ask[i].ans;
}
std::sort(ask + 1, ask + 1 + m, qwq);
for (int i = 1; i <= m; ++i) qw(ask[i].ans, '\n', true);
return 0;
}
void dfs(int x) {
dfn[st[x] = ++vistime] = x;
elur[++etime] = x;
vis[x] = true;
for (Edge *e = hd[x]; e; e = e->nxt) if (!vis[e->to]) {
deepth[e->to] = deepth[x] + 1;
dfs(e->to); elur[++etime] = x;
}
dfn[ed[x] = ++vistime] = x;
}
void MAKE_ST() {
for (int i = 1; i <= vistime; ++i) ST[0][i] = elur[i];
for (int i = 1; i < 18; ++i) {
int di = i - 1;
for (int l = 1; l <= vistime; ++l) {
int r = l + (1 << i) - 1; if (r > vistime) break;
ST[i][l] = cmp(ST[di][l], ST[di][l + (1 << di)]);
}
}
}
inline int cmp(int x, int y) {
if (deepth[x] < deepth[y]) return x;
return y;
}
inline void add(int x) {
if ((bk[x]++) == 0) ++cnt;
}
inline void dlt(int x) {
if ((--bk[x]) == 0) --cnt;
}
void update(int x) {
if ((vis[x] ^= 1)) add(MU[x]);
else dlt(MU[x]);
}
int Get_Lca(int x, int y) {
int l = st[x], r = st[y];
int len = log2(r - l + 1);
return cmp(ST[len][l], ST[len][r - (1 << len) + 1]);
}
void init_hash() {
static int temp[maxn];
memcpy(temp, MU, (n + 1) * (sizeof(int)));
std::sort(temp + 1, temp + 1 + n);
int *ed = std::unique(temp + 1, temp + 1 + n);
for (int i = 1; i <= n; ++i) MU[i] = std::lower_bound(temp + 1, ed, MU[i]) - temp;
}
inline bool qwq(const Ask &_a, const Ask &_b) {
return _a.id < _b.id;
}