BZOJ 3631 [JLOI2014] 松鼠的新家（树链剖分）
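这几乎就是一道树链剖分的模板题：按访问顺序对相邻两个房间之间的路径整体加 1（用树状数组维护差分，实现区间加、单点查），最后逐点查询即可。注意除起点 a[1] 外，每个房间都被多算了一次（中间房间是两段路径的公共端点，终点房间不需要糖果），输出时要减 1。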
#include <bits/stdc++.h>
using namespace std;

// BZOJ 3631: walk the rooms in the order a[1], a[2], ..., a[n] along tree
// paths and report, for every room, how many candies must be placed there.
//
// Approach: heavy-light decomposition of the tree (rooted at 1) maps every
// root-ward path onto O(log n) contiguous position ranges.  Each path
// a[i] -> a[i+1] is added (+1 on every vertex, both endpoints included) via a
// Fenwick tree used as a difference array; the per-room count is then a
// point query (prefix sum) at that room's position.
//
// Both the DFS and the decomposition are iterative: with up to ~3e5 nodes a
// path-shaped tree would otherwise recurse ~3e5 levels deep and risk
// overflowing the call stack.

const int N = 300000 + 10;

// Forward-star adjacency list: H[x] is the first edge leaving node x,
// X[e] the next edge after e, E[e] the target of edge e.  Edges are stored in
// both directions, so the edge arrays need twice the node capacity.
int E[N << 1], X[N << 1], H[N];
int et;

int n;
int a[N];                  // visiting order of the rooms (1-indexed)
int fa[N], deep[N];        // parent and depth (root has depth 1, parent 0)
int num[N], son[N];        // subtree size and heavy child (0 = none)
int top[N], p[N], pos;     // chain head and position of each node
int c[N];                  // Fenwick tree over positions 1..n

inline int lowbit(int x) { return x & (-x); }

// Prefix sum over the difference array = current value at position x.
inline int query(int x) {
    int ret = 0;
    for (; x; x -= lowbit(x)) ret += c[x];
    return ret;
}

// Adds val to the difference array at position x (no-op when x > n).
inline void add(int x, int val) {
    for (; x <= n; x += lowbit(x)) c[x] += val;
}

// Inserts the undirected edge u-v.
inline void addedge(int u, int v) {
    E[++et] = v, X[et] = H[u], H[u] = et;
    E[++et] = u, X[et] = H[v], H[v] = et;
}

// Computes fa/deep/num/son and then top/p for the tree rooted at `root`.
void build(int root) {
    // 1) BFS order: every parent precedes all of its children.
    vector<int> order;
    order.reserve(n);
    order.push_back(root);
    fa[root] = 0;
    deep[root] = 1;
    for (size_t k = 0; k < order.size(); ++k) {
        int x = order[k];
        for (int i = H[x]; i; i = X[i]) {
            int v = E[i];
            if (v != fa[x]) {
                fa[v] = x;
                deep[v] = deep[x] + 1;
                order.push_back(v);
            }
        }
    }

    // 2) Subtree sizes and heavy children, processed children-first.
    //    son[x] == 0 means "no child yet"; num[0] stays 0, so the first child
    //    always wins the comparison (this is what the original
    //    `son[x] != -1 || ...` test got inverted).
    for (int x : order) {
        num[x] = 1;
        son[x] = 0;
    }
    for (size_t k = order.size(); k-- > 1;) {  // index 0 is the root
        int v = order[k], x = fa[v];
        num[x] += num[v];  // num[v] is final: v's children were seen earlier
        if (num[v] > num[son[x]]) son[x] = v;
    }

    // 3) Chain decomposition: each popped node heads a new chain; walking the
    //    heavy path from it assigns consecutive positions along the chain.
    //    Light children are deferred to the stack as future chain heads.
    vector<int> stk;
    stk.push_back(root);
    pos = 0;
    while (!stk.empty()) {
        int head = stk.back();
        stk.pop_back();
        for (int x = head; x; x = son[x]) {
            top[x] = head;
            p[x] = ++pos;
            for (int i = H[x]; i; i = X[i]) {
                int v = E[i];
                if (v != fa[x] && v != son[x]) stk.push_back(v);
            }
        }
    }
}

// Adds val to every vertex on the tree path u-v (both endpoints included).
void cover(int u, int v, int val) {
    while (top[u] != top[v]) {
        // Lift whichever endpoint sits on the deeper chain head.
        if (deep[top[u]] < deep[top[v]]) swap(u, v);
        add(p[top[u]], val);
        add(p[u] + 1, -val);
        u = fa[top[u]];
    }
    // Same chain now: positions between u and v are contiguous.
    if (deep[u] > deep[v]) swap(u, v);
    add(p[u], val);
    add(p[v] + 1, -val);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("test.txt", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    if (scanf("%d", &n) != 1) return 0;  // empty input: nothing to print
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        addedge(x, y);
    }
    build(1);
    for (int i = 1; i < n; ++i) cover(a[i], a[i + 1], 1);
    // Every room other than the start a[1] is an endpoint of one path too
    // many: intermediate rooms close one path and open the next, and the
    // final room needs no candy (per the problem statement).  Assumes a[] is
    // a permutation of 1..n as the problem guarantees.
    for (int i = 1; i <= n; ++i)
        printf("%d\n", query(p[i]) - (i == a[1] ? 0 : 1));
    return 0;
}