SPOJ - COT2 Count on a tree II
题意:
给定一棵有 N 个节点的树，每个节点有一个权值。共有 M 次询问，每次询问点 u 到点 v 的路径上（包括 u 和 v）有多少种不同的权值。
题解:
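树上莫队。分块方法参照 BZOJ1086 题的方式：DFS 时用栈把树分成大小约为 sqrt(N) 的块，DFS 结束后栈中剩余的点并入最后一块；然后按照询问点 (u,v) 所在的块对询问进行排序。维护的路径集合不包含两端点的 LCA：从上一个询问 (u',v') 转移到 (u,v) 时，分别对 u→u' 和 v→v' 路径上（不含各自 LCA）的点取反，用 vis 数组标记每个点当前是否在集合中；统计答案前单独把当前询问的 LCA 加入集合，下一次转移时再把它移除。各询问的 LCA 预先用离线 Tarjan 算法求出。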
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// SPOJ COT2 - Count on a tree II.
// For every query (u, v), print how many distinct node weights occur on the
// tree path u..v (both endpoints included). Solved offline with Mo's
// algorithm on a tree: nodes are grouped into blocks during a DFS, queries
// are sorted by the blocks of their endpoints, and the path set is moved
// between consecutive queries by toggling nodes.

const int maxn = 4e4 + 10;  // capacity for nodes (1-indexed)
const int maxq = maxn * 3;  // capacity for queries (1-indexed)

int n, m;
int a[maxn];          // compressed weight of each node, in [1, #distinct]
int w[maxn];          // raw weight of each node, as read
int c[maxn];          // sorted copy of the weights, used for compression
int depth[maxn];      // depth of each node; the root (node 1) has depth 0
int num[maxn];        // num[x]: how many marked nodes have compressed weight x
int f[maxn];          // union-find parent used by the offline Tarjan LCA
int fa[maxn];         // tree parent of each node (0 for the root)
int vis[maxn];        // Tarjan: node finished; Mo: node is in the path set
int lca_[maxq];       // lca_[id]: LCA of the two endpoints of query id
int blk;              // target block size for the tree partition
int tot;              // number of blocks created
int top;              // top index of the partition stack s (0 = empty)
int s[maxn];          // stack of visited nodes not yet assigned to a block
int bel[maxn];        // block id of each node
int no;               // number of distinct weights among marked nodes
int ans[maxq];        // ans[id]: answer of query id
vector<int> g[maxn];  // adjacency lists of the tree

// Query endpoint as seen from one node: the other endpoint and the query id.
struct node {
    int to, id;
    node(int to_, int id_) : to(to_), id(id_) {}
};
vector<node> q[maxn];  // q[x]: queries that have x as one endpoint

// Adds an undirected edge u - v.
void add_edge(int u, int v) {
    g[u].push_back(v);
    g[v].push_back(u);
}

// Union-find root lookup with path compression.
int find(int x) { return x == f[x] ? x : f[x] = find(f[x]); }

// A query (l, r) with its input index; ordered by (block of l, block of r).
struct ask {
    int l, r, id;
    ask(int l_, int r_, int id_) : l(l_), r(r_), id(id_) {}
    ask() = default;
    bool operator<(const ask &o) const {
        if (bel[l] == bel[o.l]) return bel[r] < bel[o.r];
        return bel[l] < bel[o.l];
    }
};
ask as[maxq];

// Records parent/depth and partitions the tree into blocks (BZOJ1086 style):
// after each child subtree, if at least `blk` nodes have accumulated on the
// stack since entering u, they are popped into a new block. u itself is
// pushed last so it can be grouped with nodes from later siblings/ancestors.
void dfs(int u, int pre, int d) {
    fa[u] = pre;
    depth[u] = d;
    const int base = top;  // stack height on entry to u
    for (int to : g[u]) {
        if (to == pre) continue;
        dfs(to, u, d + 1);
        if (top - base >= blk) {
            ++tot;
            while (top != base) bel[s[top--]] = tot;
        }
    }
    s[++top] = u;
}

// Offline Tarjan LCA: after u's subtree is finished, each query (u, v) whose
// other endpoint v is already finished has LCA = current union-find root of v.
void tarjan(int u, int pre) {
    for (int to : g[u]) {
        if (to == pre) continue;
        tarjan(to, u);
        f[to] = u;  // merge the finished child subtree into u
    }
    vis[u] = 1;
    for (const node &e : q[u])
        if (vis[e.to]) lca_[e.id] = find(e.to);
}

// Toggles node x in or out of the maintained set, keeping `no` equal to the
// number of distinct compressed weights among marked nodes.
void xornode(int x) {
    if (vis[x]) {
        vis[x] = 0;
        if (--num[a[x]] == 0) --no;
    } else {
        vis[x] = 1;
        if (num[a[x]]++ == 0) ++no;
    }
}

// Toggles every node on the path between u and u_to, excluding their LCA.
void xorpath(int u, int u_to) {
    if (depth[u] < depth[u_to]) swap(u, u_to);
    while (depth[u] > depth[u_to]) {
        xornode(u);
        u = fa[u];
    }
    while (u != u_to) {
        xornode(u);
        xornode(u_to);
        u = fa[u];
        u_to = fa[u_to];
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        f[i] = i;
        scanf("%d", &w[i]);
        c[i] = w[i];
    }

    // Coordinate-compress the weights into [1, #distinct].
    sort(c + 1, c + n + 1);
    const int distinct_end = unique(c + 1, c + n + 1) - c;  // one past last distinct value
    for (int i = 1; i <= n; i++) a[i] = lower_bound(c + 1, c + distinct_end, w[i]) - c;

    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add_edge(x, y);
    }

    blk = max(1, static_cast<int>(sqrt(static_cast<double>(n))));
    dfs(1, 0, 0);
    // Nodes still on the stack (the root and its pending neighbourhood) are
    // merged into the last block, as in the BZOJ1086 partition; previously
    // they were left with the "unassigned" id 0.
    if (tot == 0) tot = 1;
    while (top) bel[s[top--]] = tot;

    for (int i = 1; i <= m; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        if (bel[x] > bel[y]) swap(x, y);
        q[x].emplace_back(y, i);
        q[y].emplace_back(x, i);
        as[i] = ask(x, y, i);
    }
    sort(as + 1, as + m + 1);
    tarjan(1, 0);
    memset(vis, 0, sizeof(vis));  // vis is reused as the Mo membership flag

    // Invariant: the marked set is path(la_u, la_v) without its LCA, plus
    // la_lca toggled on. Start from the empty path at the root.
    int la_u = 1, la_v = 1, la_lca = 1;
    xornode(1);
    for (int i = 1; i <= m; i++) {
        const ask &cur = as[i];
        const int cur_lca = lca_[cur.id];
        xorpath(cur.l, la_u);
        xorpath(cur.r, la_v);
        xornode(la_lca);   // remove the previous query's LCA
        xornode(cur_lca);  // add the current query's LCA
        la_u = cur.l;
        la_v = cur.r;
        la_lca = cur_lca;
        ans[cur.id] = no;
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}