Count On A Tree II.
$n$ 个点的树,每个点有一个权值;每次询问一条链上有多少种不同的权值
sol:
树上莫队
首先,用“王室联邦”方法对树分块(块大小取 $\sqrt n$),再把询问按两个端点所在的块排序
记 $(cu,cv)$ 为当前的链,$(qu,qv)$ 为当前询问的链,维护一个数组表示“当前点在/不在当前链上”(链上的点不含两端点的 lca)。每次暴力从 $cu,qu$ 分别爬到它们的 lca,从 $cv,qv$ 分别爬到它们的 lca,把经过的点的状态取反(lca 本身不取反),最后特判一下 $qu,qv$ 的 lca 是否需要计入答案就可以了
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
// Count On A Tree II: every vertex of a tree carries an integer value; each
// query (u, v) asks how many distinct values lie on the tree path u..v.
//
// Solution: Mo's algorithm on a tree.
//  * Vertices are grouped into blocks with the "royal federation" (王室联邦)
//    decomposition, target block size blk = sqrt(n).
//  * Queries are sorted by (block of one endpoint, block of the other).
//  * The current vertex set is path(cu, cv) minus lca(cu, cv). Moving to the
//    next query (nu, nv) toggles path(cu, nu) minus its lca and
//    path(cv, nv) minus its lca; the lca of the new query is then checked
//    separately when computing the answer.
// NOTE(review): all DFS passes are recursive; recursion depth equals the tree
// height, so very deep (path-like) trees may need a larger stack.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <vector>

namespace {

// Reads one decimal integer from stdin; each '-' seen before the digits flips
// the sign. Returns 0 at end of input instead of spinning forever.
int read() {
    int x = 0, f = 1;
    int ch = std::getchar();  // int, not char: must be able to hold EOF
    for (; !std::isdigit(ch); ch = std::getchar()) {
        if (ch == EOF) return 0;
        if (ch == '-') f = -f;
    }
    for (; std::isdigit(ch); ch = std::getchar()) x = 10 * x + (ch - '0');
    return x * f;
}

const int maxn = 540010;

int n, m;
int a[maxn];               // value of each vertex, compressed to 1..n
int b[maxn];               // sorted raw values, used for the compression
std::vector<int> G[maxn];  // adjacency lists, vertices numbered 1..n
int fa[maxn];              // parent of each vertex (fa[1] == 0)
int dep[maxn];             // depth of each vertex, the root has depth 1

// Heavy-light decomposition, used only to answer LCA queries.
namespace splca {
int size[maxn];  // subtree size
int top[maxn];   // head of the heavy chain containing the vertex

// Fills fa, dep and size for the subtree rooted at x.
void dfs1(int x) {
    size[x] = 1;
    for (int to : G[x]) {
        if (to == fa[x]) continue;
        fa[to] = x;
        dep[to] = dep[x] + 1;
        dfs1(to);
        size[x] += size[to];
    }
}

// Assigns chain heads: x's largest child continues x's chain (head col),
// every other child starts a chain of its own.
void dfs2(int x, int col) {
    top[x] = col;
    int heavy = 0;  // size[0] == 0, so any real child beats this sentinel
    for (int to : G[x])
        if (to != fa[x] && size[to] > size[heavy]) heavy = to;
    if (!heavy) return;
    dfs2(heavy, col);
    for (int to : G[x])
        if (to != fa[x] && to != heavy) dfs2(to, to);
}

// Lowest common ancestor of x and y via chain jumps.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}
}  // namespace splca

// Roots the tree at vertex 1 and prepares LCA queries (also fills fa / dep).
void lca_init() {
    fa[1] = 0;
    dep[1] = 1;
    splca::dfs1(1);
    splca::dfs2(1, 1);
}

int lca(int x, int y) { return splca::lca(x, y); }

int blk;       // target block size
int bcnt;      // id of the next block to be created
int bl[maxn];  // block id of each vertex
int stk[maxn]; // vertices not yet assigned to a block
int stkTop;    // number of entries currently in stk

// Royal-federation blocking of x's subtree. Returns how many vertices of the
// subtree are still waiting on stk. Those are always the topmost entries of
// stk, and there are at most blk of them. Whenever the pending vertices
// collected from x's children reach blk, they are closed into one block
// (size < 2*blk). x itself is pushed last so it can join a block formed by
// an ancestor.
int buildBlocks(int x) {
    int pending = 0;
    for (int to : G[x]) {
        if (to == fa[x]) continue;
        pending += buildBlocks(to);
        if (pending >= blk) {
            // Pops exactly `pending` vertices and leaves pending == 0.
            for (; pending > 0; --pending) bl[stk[--stkTop]] = bcnt;
            ++bcnt;
        }
    }
    stk[stkTop++] = x;  // push: write then advance, matching pop `stk[--stkTop]`
    return pending + 1;
}

struct Ques {
    int u, v;    // endpoints, chosen so that bl[u] <= bl[v]
    int fl, fr;  // bl[u] and bl[v], the Mo sort keys
    int id;      // 1-based position of the query in the input
    bool operator<(const Ques &o) const {
        return fl == o.fl ? fr < o.fr : fl < o.fl;
    }
};

Ques qs[maxn];
int ans[maxn];
int vis[maxn];  // vis[c]: number of vertices with value c in the current set
int inq[maxn];  // inq[x]: 1 if vertex x is in the current set
int now;        // number of values c with vis[c] > 0

// Toggles vertex x in/out of the current set, then steps x to its parent.
void toggleUp(int &x) {
    if (inq[x]) {
        if (--vis[a[x]] == 0) --now;
    } else if (++vis[a[x]] == 1) {
        ++now;
    }
    inq[x] ^= 1;
    x = fa[x];
}

}  // namespace

int main() {
    n = read();
    m = read();
    blk = std::max(1, static_cast<int>(std::sqrt(n)));

    // Read and compress vertex values to 1..n.
    for (int i = 1; i <= n; ++i) b[i] = a[i] = read();
    std::sort(b + 1, b + n + 1);
    for (int i = 1; i <= n; ++i)
        a[i] = static_cast<int>(std::lower_bound(b + 1, b + n + 1, a[i]) - b);

    for (int i = 2; i <= n; ++i) {
        int u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }

    lca_init();
    buildBlocks(1);
    // Vertices left over near the root form one final block.
    while (stkTop) bl[stk[--stkTop]] = bcnt;

    for (int i = 1; i <= m; ++i) {
        int v = read(), u = read();
        if (bl[v] > bl[u]) std::swap(u, v);
        qs[i] = Ques{v, u, bl[v], bl[u], i};
    }
    std::sort(qs + 1, qs + m + 1);

    // The current set starts as path(1, 1) minus its lca, i.e. empty.
    int cu = 1, cv = 1;
    for (int i = 1; i <= m; ++i) {
        int nu = qs[i].u, nv = qs[i].v;
        int anc = lca(cu, nu);
        while (cu != anc) toggleUp(cu);
        while (nu != anc) toggleUp(nu);
        anc = lca(cv, nv);
        while (cv != anc) toggleUp(cv);
        while (nv != anc) toggleUp(nv);
        cu = qs[i].u;
        cv = qs[i].v;
        // The lca of the query is not in the set; count its value if absent.
        anc = lca(cu, cv);
        ans[qs[i].id] = now + (vis[a[anc]] == 0 ? 1 : 0);
    }
    for (int i = 1; i <= m; ++i) std::printf("%d\n", ans[i]);
}