BZOJ - 2500 树形DP乱搞
题意:给出一棵树,两个给给的人在第\(i\)天会从节点\(i\)沿着最长路径走,求最长的连续天数\([L,R]\)使得\([L,R]\)为起点的最长路径极差不超过m
求\(1\)到\(n\)的最长路经可用树形DP求解,
设\(f[i]\):\(i\)的子树下到\(i\)的最远距离
\(g[i]\):\(i\)子树下除了\(f[i]\)子树以外的最远距离
\(h[i]\):除了\(i\)子树以外到\(i\)的最远距离
\(h[i]\)从父到儿子的转移需要判断\(i\)到底是\(fa\)的最远距离所在边还是次远距离所在边(可相等),还有直接来自父亲以上\(h[fa]\)的转移
搞完后求极差就用二分+RMQ强行求出来,注意初始化需要f和h的对比
题目简单但要细心
#include<bits/stdc++.h>
#define rep(i,j,k) for(register int i=j;i<=k;i++)
#define rrep(i,j,k) for(register int i=j;i>=k;i--)
#define erep(i,u) for(register int i=head[u];~i;i=nxt[i])
#define fastIO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
#define println(x) printf("%lld\n",(ll)(x))
using namespace std;
typedef long long ll;
const int MAXN = 1e6+11;
const int MOD = 142857;
const int INF = 1<<30;
int to[MAXN<<1],nxt[MAXN<<1],head[MAXN],tot;
int cost[MAXN<<1];
int n,m;
void init(int n){memset(head,-1,(n+2)*sizeof(int)),tot=0;}
void add(int u,int v,ll w){
to[tot]=v;
cost[tot]=w;
nxt[tot]=head[u];
head[u]=tot++;
}
int f[MAXN],g[MAXN],h[MAXN];
int mx[MAXN][22],mn[MAXN][22];
void DP0(int u,int fa){
f[u]=g[u]=h[u]=0;
for(int i=head[u];~i;i=nxt[i]){
int v=to[i]; ll w=cost[i];
if(v==fa) continue;
DP0(v,u);
if(f[v]+w>f[u]){
g[u]=f[u]; //次长子树
f[u]=f[v]+w; //最长子树
}else if(f[v]+w>g[u]){
g[u]=f[v]+w;
}
}
}
void DP1(int u,int fa){
for(int i=head[u];~i;i=nxt[i]){
int v=to[i]; ll w=cost[i];
if(v==fa) continue;
if(f[u]-w==f[v]) h[v]=max(h[u]+w,g[u]+w);//本身v作为儿子是f[u]的最大值,那就从u的次大子树中转移
else h[v]=max(h[u]+w,f[u]+w);
DP1(v,u);
}
}
ll C(int lo,int hi){
int k=log2(hi-lo+1);
return max(mx[lo][k],mx[hi-(1<<k)+1][k])
-min(mn[lo][k],mn[hi-(1<<k)+1][k]);
}
int gao(int st){
int lo=st,hi=n;
while(lo<hi){
int mid=lo+(hi-lo+1)/2;
if(C(st,mid)<=m) lo=mid;
else hi=mid-1;
}
return C(st,lo)?lo:lo-1;
}
int main(){
#ifndef ONLINE_JUDGE
freopen("stdin.txt","r",stdin);
#endif
while(~scanf("%d%d",&n,&m)){
init(n);
for(int i=2;i<=n;i++){
int fi;ll di;
scanf("%d%lld",&fi,&di);
add(i,fi,di);
add(fi,i,di);
}
DP0(1,-1);
DP1(1,-1);
for(int i=1;i<=n;i++){
mx[i][0]=mn[i][0]=max(f[i],h[i]);//f[]只考虑子树内,h[]只考虑子树外
}
int t=log2(n);
for(int i=1;i<=t;i++){
for(int j=1;j<=n;j++){
mx[j][i]=max(mx[j][i-1],mx[j+(1<<i-1)][i-1]);
mn[j][i]=min(mn[j][i-1],mn[j+(1<<i-1)][i-1]);
}
}
int ans=0;
for(int i=1;i<=n;i++){
int hi=gao(i);
ans=max(ans,hi-i+1);
}
println(ans);
}
return 0;
}