「题解」洛谷 SP10707 COT2 - Count on a tree II
题目
SP10707 COT2 - Count on a tree II
简化题意
给定一棵树，每个点有一个点权。多次询问：点 $x$ 到点 $y$ 的路径（含两端点）上有多少种不同的点权。
思路
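树上莫队。先用 DFS 求出括号序：每个点在进入和离开时各记录一次，$\mathrm{fir}[u]$、$\mathrm{la}[u]$ 分别为两次出现的位置。设 $\mathrm{fir}[x] < \mathrm{fir}[y]$：若 $x$ 是 $y$ 的祖先，路径对应区间 $[\mathrm{fir}[x], \mathrm{fir}[y]]$；否则对应区间 $[\mathrm{la}[x], \mathrm{fir}[y]]$，并需额外计入 $\mathrm{lca}(x, y)$。区间中出现两次的点不在路径上，所以每次遇到一个点就翻转它的"是否在集合中"状态，再套用普通序列莫队即可。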
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <iostream>
#include <algorithm>
#define MAXN 100001
// Global state for the tree-Mo solution.
//   n    - number of nodes, m - number of queries
//   sqrn - Mo block size over the bracket sequence
//   pthn - number of adjacency-list edges inserted so far
// Arrays indexed by node id (1..n), compressed weight (<= n) or query id
// (1..m) use MAXN. NOTE(review): this assumes both n and m are < MAXN.
int n, m, sqrn, pthn, head[MAXN], a[MAXN], ans[MAXN], cnt[MAXN];
// num[i] = Mo block of bracket-sequence position i. Positions run 1..2n,
// so this array needs twice the node capacity.
int num[2 * MAXN];
//   c      - bracket-sequence position counter used by dfs
//   dep    - node depth (root has depth 1)
//   lg[i]  - floor(log2(i)) + 1
//   fa     - binary-lifting ancestor table
//   fir/la - first/last position of a node in the bracket sequence
//   vis    - whether a node is currently toggled into the Mo window
int c, dep[MAXN], lg[MAXN], fa[MAXN][21], fir[MAXN], la[MAXN], vis[MAXN];
// pre[pos] = node at bracket-sequence position pos (1..2n).
int pre[2 * MAXN];
// One Mo query expressed as an interval [x, y] of the bracket sequence.
struct query {
    int x;    // left endpoint (bracket-sequence position)
    int y;    // right endpoint (bracket-sequence position)
    int lca;  // extra node to count for this query, or 0 if none
    int id;   // original query index, used to place the answer
    // Mo ordering: by the block of the left endpoint, then by right endpoint.
    friend bool operator < (const query &lhs, const query &rhs) {
        const int lhsBlock = num[lhs.x];
        const int rhsBlock = num[rhs.x];
        return lhsBlock != rhsBlock ? lhsBlock < rhsBlock : lhs.y < rhs.y;
    }
} q[MAXN];
// (weight, node) pair used to coordinate-compress node weights.
struct lsh {
    int w;   // raw weight read from input
    int id;  // node that owns this weight
    // Compare by weight only; the node index does not take part.
    friend bool operator < (const lsh &lhs, const lsh &rhs) {
        return rhs.w > lhs.w;
    }
} b[MAXN];
// Singly linked adjacency-list edge: `next` is the previous head of the
// owning node's list (0 terminates), `to` is the neighbour.
struct Edge {
    int next;
    int to;
};
// Every undirected tree edge is stored in both directions, giving 2(n - 1)
// entries. Size the pool at twice the node capacity so it cannot overflow
// when n approaches MAXN.
Edge pth[2 * MAXN];
// Pushes the directed edge from -> to onto the front of `from`'s adjacency
// list. Edge ids start at 1 so that 0 can serve as the list terminator.
void add(int from, int to) {
    ++pthn;
    pth[pthn].next = head[from];
    pth[pthn].to = to;
    head[from] = pthn;
}
// Builds the bracket sequence of the subtree rooted at u.
// Entering u records position fir[u], leaving u records position la[u], and
// pre[] maps each position back to u. It also fills dep[] and each node's
// direct parent fa[u][0]. The caller passes father = 0 for the root.
void dfs(int u, int father) {
    fa[u][0] = father;
    dep[u] = dep[father] + 1;
    pre[++c] = u;
    fir[u] = c;
    for (int e = head[u]; e != 0; e = pth[e].next) {
        const int v = pth[e].to;
        if (v == father) continue;  // do not walk back up the tree
        dfs(v, u);
    }
    pre[++c] = u;
    la[u] = c;
}
// Lowest common ancestor of x and y via binary lifting.
// Relies on dep[], fa[][] and lg[] (lg[i] = floor(log2 i) + 1) being filled.
int lca(int x, int y) {
    if (dep[y] > dep[x]) std::swap(x, y);
    // Lift the deeper node. Each step jumps by the largest power of two that
    // does not exceed the remaining depth gap.
    for (int gap = dep[x] - dep[y]; gap > 0; gap = dep[x] - dep[y]) {
        x = fa[x][lg[gap] - 1];
    }
    if (x == y) return x;
    // Raise both nodes while their ancestors still differ; they end up as
    // children of the LCA.
    for (int k = lg[dep[x]] - 1; k >= 0; --k) {
        const int upX = fa[x][k];
        const int upY = fa[y][k];
        if (upX == upY) continue;
        x = upX;
        y = upY;
    }
    return fa[x][0];
}
// Toggles node `now` into or out of the current Mo window.
// `t` is the number of distinct compressed weights present, and cnt[w] is
// the number of in-window nodes with weight w. A node that occurs twice in
// the bracket-sequence range is toggled back out by its second occurrence.
void work(int now, int &t) {
    int &occ = cnt[a[now]];
    if (vis[now]) {
        --occ;
        if (occ == 0) --t;  // last node with this weight left the window
    } else {
        if (occ == 0) ++t;  // first node with this weight entered the window
        ++occ;
    }
    vis[now] ^= 1;
}
int main() {
scanf("%d %d", &n, &m), sqrn = 2 * n / sqrt(m);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
b[i].w = a[i], b[i].id = i;
}
for (int i = 1; i <= 2 * n; ++i) {
num[i] = (i - 1) / sqrn + 1;
}
std::sort(b + 1, b + n +1);
for (int i = 1, now = 0; i <= n; ++i) {
if (b[i].w != b[i - 1].w) ++now;
a[b[i].id] = now;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d %d", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i - 1] + ((1 << lg[i - 1]) == i);
}
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i <= n; ++i) {
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
}
for (int i = 1, x, y, f; i <= m; ++i) {
scanf("%d %d", &x, &y);
f = lca(x, y), q[i].id = i;
if (fir[x] > fir[y]) std::swap(x, y);
if (f == x) {
q[i].x = fir[x];
q[i].y = fir[y];
}
else {
q[i].x = la[x];
q[i].y = fir[y];
q[i].lca = f;
}
}
std::sort(q + 1, q + m + 1);
int l = 1, r = 0, now = 0;
for (int i = 1; i <= m; ++i) {
while (l > q[i].x) work(pre[--l], now);
while (r < q[i].y) work(pre[++r], now);
while (l < q[i].x) work(pre[l++], now);
while (r > q[i].y) work(pre[r--], now);
if (q[i].lca) work(q[i].lca, now);
ans[q[i].id] = now;
if (q[i].lca) work(q[i].lca, now);
}
for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
return 0;
}