P3066 [USACO12DEC]Running Away From the Barn G——树上差分
题面
给定一颗 \(n\) 个点的有根树,边有边权,节点从 \(1\) 至 \(n\) 编号,\(1\) 号节点是这棵树的根。
再给出一个参数 \(t\),对于树上的每个节点 \(u\),请求出 \(u\)的子树中有多少节点满足该节点到 \(u\) 的距离不大于 \(t\)
数据范围
- \(1 \leq n \leq 2 \times 10^5,1 \leq t \leq 10^{18}\)。
- \(1 \leq p_i \lt i,1 \leq w_i \leq 10^{12}\)。
思路
子树,我们会想到什么,树上差分
考虑每个节点的贡献,每个节点对它到一个距离小于的L的最远祖先路径上所有的点都产生1的贡献(包括自己这个点)
那么我们肯定不能暴力跳来找祖先,我们可以用倍增的思想,最多跳\(log_{dep}\)次
于是我们有
for(int i=1;i<=n;i++)
for(int j=1;j<=19;j++)
f[i][j]=f[f[i][j-1]][j-1];
然后我们去找符合条件的每个点的祖先,注意我们循环的时候,是保证了这个点也是可以对最远的祖先产生贡献的,那么我们点差分的时候就是对\(val[fa[lca[i]]]--,val[i]++\),相当于在这条路径上都加了1
最后我们求答案的时候,再\(dfs\)一遍就行了
分步讲解
1.处理倍增数组和祖先数组.一次\(dfs\)求出信息
void dfs1(int x,int father)
{
f[x][0]=father;
fa[x]=father;
for(int i=head[x];i!=-1;i=ne[i])
{
int j=ver[i];
if(j==father) continue;
dist[j]=dist[x]+e[i];
dfs1(j,x);
}
}
2.找符合条件的最远祖先
int getfa(int x,int st)
{
for(int i=19;i>=0;i--)
{
if(f[x][i]!=0&&(dist[st]-dist[f[x][i]])<=L)
{
x=f[x][i];
}
}
return x;
}
3.树上差分
for(int i=1;i<=n;i++)
{
int top=getfa(i,i);
val[fa[top]]--;
val[i]++;
}
4.统计答案
因为是算子树,所以我们需要先递归到底,然后在回溯的时候,累加答案
int getfa(int x,int st)
{
for(int i=19;i>=0;i--)
{
if(f[x][i]!=0&&(dist[st]-dist[f[x][i]])<=L)
{
x=f[x][i];
}
}
return x;
}
完整代码
#include<bits/stdc++.h>
using namespace std;
const int N=200100;
int f[200055][20];
int ne[N],head[N],ver[N];
long long e[N];
int idx;
long long n,L;
int fa[N];
long long dist[N];
long long val[N];
long long ans[N];
void add(int u,int v,long long w)
{
ne[idx]=head[u];
ver[idx]=v;
e[idx]=w;
head[u]=idx;
idx++;
}
void dfs1(int x,int father)
{
f[x][0]=father;
fa[x]=father;
for(int i=head[x];i!=-1;i=ne[i])
{
int j=ver[i];
if(j==father) continue;
dist[j]=dist[x]+e[i];
dfs1(j,x);
}
}
int getfa(int x,int st)
{
for(int i=19;i>=0;i--)
{
if(f[x][i]!=0&&(dist[st]-dist[f[x][i]])<=L)
{
x=f[x][i];
}
}
return x;
}
void dfs2(int x,int father)
{
ans[x]=val[x];
for(int i=head[x];i!=-1;i=ne[i])
{
int j=ver[i];
if(j==father) continue;
dfs2(j,x);
ans[x]+=ans[j];
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%lld%lld",&n,&L);
for(int i=2;i<=n;i++)
{
int u;
long long w;
scanf("%d%lld",&u,&w);
add(u,i,w);
}
dfs1(1,0);
for(int j=1;j<=n;j++)
for(int i=1;i<=19;i++)
{
f[j][i]=f[f[j][i-1]][i-1];
}
for(int i=1;i<=n;i++)
{
int top=getfa(i,i);
val[fa[top]]--;
val[i]++;
}
dfs2(1,0);
for(int i=1;i<=n;i++)
printf("%d\n",ans[i]);
return 0;
}