[luoguSP10707] Count on a tree II
题意
原题链接
给定一棵树,节点 \(i\) 上有颜色 \(c_i\),多次询问,每次查询两点之间的路径中的不同颜色数。
sol
这是一道类似普通莫队 [luoguSP3267] D-query 的题目,但是是在树上询问的,因此考虑将树转化为序列计算。将树转化为序列包括 DFS 序,欧拉序和树链剖分三种,树链剖分复杂度更大,而 DFS 序无法解决本题,因此只能使用欧拉序。容易发现,每次询问点 \(x,y\) 会有两种情况(下设 \(rk1_i\) 表示欧拉序中第一次出现 \(i\) 的位置,\(rk2_i\) 表示欧拉序中第二次出现 \(i\) 的位置,且 \(rk1_x < rk1_y\)):
- \(\operatorname{lca}(x,y) = x\),此时路径即为从 \(x \to y\),也就是区间 \([rk1_x,rk1_y]\);
- \(\operatorname{lca}(x,y) \ne x\),此时路径为 \(x \to \operatorname{lca}(x,y) \to y\)。考虑到可能中间会遍历到其他的子树,因此这段路径对应区间 \([rk2_x,rk1_y]\),去除出现两次的点并加上 \(\operatorname{lca}(x, y)\)(因为 \(\operatorname{lca}(x,y)\) 一定会在 \(rk2_y\) 后出现)。
这样,就将树上问题转化为了序列上的问题,按照普通莫队的方法做即可。
有一个 trick,可以记录一个 \(st\) 数组表示遍历次数,如果第奇数次遍历则添加该点,否则删除该点,这样就将添加、删除、去双三个操作合并为了一个。
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <unordered_map>
#include <cmath>
using namespace std;
const int N = 40005, M = 100005;
int h[N], e[M], ne[M], idx;
int c[N];
int eular[N * 2], rk1[N], rk2[N], timestamp;
int fa[N][16], dep[N];
int block[N], ans[M], res;
unordered_map<int, int> buc;
bool st[N];
int n, m;
struct Ask {
int l, r, lca, id;
bool operator< (const Ask &W) const {
if (block[l] != block[W.l]) return block[l] < block[W.l];
if (block[l] & 1) return r < W.r;
else return r > W.r;
}
} queries[M];
void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs_init(int u, int father){
dep[u] = dep[father] + 1, fa[u][0] = father;
eular[ ++ timestamp] = u, rk1[u] = timestamp;
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == father) continue;
dfs_init(j, u);
}
eular[ ++ timestamp] = u, rk2[u] = timestamp;
}
int lca(int x, int y){
if (dep[x] < dep[y]) swap(x, y);
for (int i = 15; i >= 0; i -- ) {
int fax = fa[x][i];
if (dep[fax] >= dep[y]) x = fax;
}
if (x == y) return x;
for (int i = 15; i >= 0; i -- ){
int fax = fa[x][i], fay = fa[y][i];
if (fax != fay) x = fax, y = fay;
}
return fa[x][0];
}
void calc(int x){
st[x] = !st[x];
if (st[x]) {
if (!buc[c[x]]) res ++ ;
buc[c[x]] ++ ;
} else {
buc[c[x]] -- ;
if (!buc[c[x]]) res -- ;
}
}
int main(){
memset(h, -1, sizeof h);
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &c[i]);
for (int i = 1; i < n; i ++ ) {
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
dfs_init(1, 0);
for (int k = 1; k <= 15; k ++ )
for (int i = 1; i <= n; i ++ )
fa[i][k] = fa[fa[i][k - 1]][k - 1];
for (int i = 1; i <= m; i ++ ){
int x, y;
scanf("%d%d", &x, &y);
if (rk1[x] > rk1[y]) swap(x, y);
int l = lca(x, y);
if (x == l) queries[i] = {rk1[x], rk1[y], 0, i};
else queries[i] = {rk2[x], rk1[y], l, i};
}
int Sz = sqrt(n * 2);
int bcnt = ceil(2.0 * n / Sz);
for (int i = 1; i <= bcnt; i ++ )
for (int j = (i - 1) * Sz + 1; j <= min(i * Sz, n * 2); j ++ )
block[j] = i;
sort(queries + 1, queries + m + 1);
int l = 1, r = 0;
for (int i = 1; i <= m; i ++ ){
Ask &q = queries[i];
while (l > q.l) calc(eular[ -- l]);
while (r < q.r) calc(eular[ ++ r]);
while (l < q.l) calc(eular[l ++ ]);
while (r > q.r) calc(eular[r -- ]);
if (q.lca) calc(q.lca);
ans[q.id] = res;
if (q.lca) calc(q.lca);
}
for (int i = 1; i <= m; i ++ ) printf("%d\n", ans[i]);
}
蒟蒻犯的若至错误
- 写莫队的时候
while
写成if
了 awa。