CF208E (树上启发式合并)

给一颗树,每次询问以v为根的子树中距v相对深度为p的节点有多少个。

  • 树上倍增算出每个询问的点的祖先,问题就转换成了上述问题
  • 然后就按照DSU on tree的流程。处理轻儿子贡献,处理重儿子贡献,回来计算u的贡献:暴力跑轻儿子加到重儿子计算过的cnt中,统计答案,删除轻儿子贡献。
  • 这里count 算以下深度和u相同的个数就好
#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false) ,cin.tie(0), cout.tie(0)
//#pragma GCC optimize(3,"Ofast","inline")
#define ll long long
const int N = 1e5 + 5;
const int M = 1e6 + 5;
const int INF = 0x3f3f3f3f;
const ll LNF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
int h[N], e[N << 1], ne[N << 1], idx;
int n, F[N][21];
int sz[N], son[N], flag, ans[N], cnt[N], deep[N];

struct node {
   int len, id;
};
vector<node> Q[N];
void add(int a, int b) {
   e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void ST_create () {
   int k = 20;
   for ( int j = 1; j <= k; ++ j ) {
       for ( int i = 1; i <= n; ++ i ) {
           F[i][j] = F[F[i][j - 1]][j - 1];
       }
   }
}
int find (int x, int k) {
   int i = 0;
   while(k) {
       if(k & 1) x = F[x][i];
       ++ i; k >>= 1;
   }
   return x;
}
void dfs (int u, int fa) {
   sz[u] = 1;
   for ( int i = h[u]; ~i; i = ne[i] ) {
       int v = e[i]; if( v == fa ) continue;
       F[v][0] = u; deep[v] = deep[u] + 1;
       dfs(v, u);
       
       sz[u] += sz[v];
       if(sz[v] > sz[son[u]]) son[u] = v;
   }
}
void count( int u, int fa, int val ) {
   cnt[deep[u]] += val;
   for ( int i = h[u]; ~i; i = ne[i] ) {
       int v = e[i]; if(v == fa || v == flag) continue; 
   	//注意这里为什么不是v == son[u] continue
   	//我们是除了第一次进来count的那个u的重儿子其他全暴力跑一遍
       count(v, u, val);
   }
}
void dfs ( int u, int fa, bool keep ) {
   for ( int i = h[u]; ~i; i = ne[i] ) {
       int v = e[i]; if( v == fa || v == son[u] ) continue;
       dfs(v, u, false);
   }
   if(son[u]) {
       dfs(son[u], u, true);
       flag = son[u];
   }
   count(u, fa, 1);
   flag = 0;
   for( auto i : Q[u] ) {
       ans[i.id] = cnt[deep[u] + i.len] - 1;
   }
   if(!keep) {
       count(u, fa, -1);
   }
}
int main() {
   IOS;
   memset(h, -1, sizeof h);
   cin >> n;
   for ( int i = 1; i <= n; ++ i ) {
       int u; cin >> u;
       add(u, i), add(i, u);
   }
   dfs(0, -1);
   ST_create();

   int m; cin >> m;
   for( int i = 1; i <= m; ++ i) {
       int v, p; cin >> v >> p;
       v = find(v, p);
       if (!v || v == -1)
   		ans[i] = 0;
       else Q[v].push_back({p, i});
   }
   
   dfs(0, -1, 0);
   for ( int i = 1; i <= m; ++ i ) cout << ans [i] << " ";
   cout << '\n'; 

   return 0;
}
posted @ 2022-03-29 17:32  qingyanng  阅读(18)  评论(0编辑  收藏  举报