CF208E (树上启发式合并)
给一颗树,每次询问以v为根的子树中距v相对深度为p的节点有多少个。
- 树上倍增算出每个询问的点的祖先,问题就转换成了上述问题
- 然后就按照DSU on tree的流程。处理轻儿子贡献,处理重儿子贡献,回来计算u的贡献:暴力跑轻儿子加到重儿子计算过的cnt中,统计答案,删除轻儿子贡献。
- 这里count 算以下深度和u相同的个数就好
#include<bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false) ,cin.tie(0), cout.tie(0)
//#pragma GCC optimize(3,"Ofast","inline")
#define ll long long
const int N = 1e5 + 5;
const int M = 1e6 + 5;
const int INF = 0x3f3f3f3f;
const ll LNF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
int h[N], e[N << 1], ne[N << 1], idx;
int n, F[N][21];
int sz[N], son[N], flag, ans[N], cnt[N], deep[N];
struct node {
int len, id;
};
vector<node> Q[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void ST_create () {
int k = 20;
for ( int j = 1; j <= k; ++ j ) {
for ( int i = 1; i <= n; ++ i ) {
F[i][j] = F[F[i][j - 1]][j - 1];
}
}
}
int find (int x, int k) {
int i = 0;
while(k) {
if(k & 1) x = F[x][i];
++ i; k >>= 1;
}
return x;
}
void dfs (int u, int fa) {
sz[u] = 1;
for ( int i = h[u]; ~i; i = ne[i] ) {
int v = e[i]; if( v == fa ) continue;
F[v][0] = u; deep[v] = deep[u] + 1;
dfs(v, u);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void count( int u, int fa, int val ) {
cnt[deep[u]] += val;
for ( int i = h[u]; ~i; i = ne[i] ) {
int v = e[i]; if(v == fa || v == flag) continue;
//注意这里为什么不是v == son[u] continue
//我们是除了第一次进来count的那个u的重儿子其他全暴力跑一遍
count(v, u, val);
}
}
void dfs ( int u, int fa, bool keep ) {
for ( int i = h[u]; ~i; i = ne[i] ) {
int v = e[i]; if( v == fa || v == son[u] ) continue;
dfs(v, u, false);
}
if(son[u]) {
dfs(son[u], u, true);
flag = son[u];
}
count(u, fa, 1);
flag = 0;
for( auto i : Q[u] ) {
ans[i.id] = cnt[deep[u] + i.len] - 1;
}
if(!keep) {
count(u, fa, -1);
}
}
int main() {
IOS;
memset(h, -1, sizeof h);
cin >> n;
for ( int i = 1; i <= n; ++ i ) {
int u; cin >> u;
add(u, i), add(i, u);
}
dfs(0, -1);
ST_create();
int m; cin >> m;
for( int i = 1; i <= m; ++ i) {
int v, p; cin >> v >> p;
v = find(v, p);
if (!v || v == -1)
ans[i] = 0;
else Q[v].push_back({p, i});
}
dfs(0, -1, 0);
for ( int i = 1; i <= m; ++ i ) cout << ans [i] << " ";
cout << '\n';
return 0;
}