poj3162(树形dp+线段树)
题目链接:
http://poj.org/problem?id=3162
题意:n个结点构成一棵树 ,MC将在n天,依次按结点编号设为起点,选取距离起点最远的结点作为终点,得到最远距离。问:找到一个区间,使得这个区间里最大最小值的差距不超过m,求区间的最大长度。
解题思路:
求每天的最远距离很明显是树形dp的问题,求的n个值后,我们可以用线段树来存储这些值。每次维护区间的l,r,用线段树求的该区间的最大最小值,然后判断该区间是否符合要求。可以把l,r初始值都设为1,然后往右移动即可。
#include<iostream> #include<stdio.h> #include<string.h> #define inf 0x3f3f3f3f using namespace std; const int mod=1e9+7; const int maxn=1e6+5; typedef long long ll; int cnt; int head[maxn]; struct st{ int to,next,w; }stm[maxn*2]; void add(int u,int v,int w){ stm[cnt].to=v; stm[cnt].next=head[u]; stm[cnt].w=w; head[u]=cnt++; } int dis[maxn]; int n,m; int dp[maxn][3]; void dfs1(int now,int fa){ dp[now][0]=0;//子树方向最远 dp[now][1]=0;//子树方向次远 for(int i=head[now];~i;i=stm[i].next){ int w=stm[i].w; int to=stm[i].to; if(to==fa)continue; dfs1(to,now); if(dp[to][0]+w>dp[now][0]){ dp[now][1]=dp[now][0]; dp[now][0]=dp[to][0]+w; } else if(dp[to][0]+w>dp[now][1]){ dp[now][1]=dp[to][0]+w; } } } void dfs2(int now,int fa){ for(int i=head[now];~i;i=stm[i].next){ int w=stm[i].w; int to=stm[i].to; if(to==fa)continue; if(dp[now][0]>dp[to][0]+w){ dp[to][2]=max(dp[now][0],dp[now][2])+w;//父节点方向最远 } else { dp[to][2]=max(dp[now][1],dp[now][2])+w; } dfs2(to,now); } } int tre[maxn*4][2]; void pushup(int rt){ tre[rt][0]=max(tre[rt<<1][0],tre[rt<<1|1][0]); tre[rt][1]=min(tre[rt<<1][1],tre[rt<<1|1][1]); } void build(int l,int r,int rt){ if(l==r){ tre[rt][0]=tre[rt][1]=dis[l]; return ; } int mid=(l+r)/2; build(l,mid,rt*2); build(mid+1,r,rt*2+1); pushup(rt); } int query0(int lm,int rm,int l,int r,int rt){ if(lm<=l&&rm>=r){ return tre[rt][0]; } int mid=(l+r)/2; int ans=0; if(mid>=lm){ ans=max(ans,query0(lm,rm,l,mid,rt*2)); } if(mid<rm){ ans=max(ans,query0(lm,rm,mid+1,r,rt*2+1)); } return ans; } int query1(int lm,int rm,int l,int r,int rt){ if(lm<=l&&rm>=r){ return tre[rt][1]; } int mid=(l+r)/2; int ans=inf; if(mid>=lm){ ans=min(ans,query1(lm,rm,l,mid,rt*2)); } if(mid<rm){ ans=min(ans,query1(lm,rm,mid+1,r,rt*2+1)); } return ans; } int main(){ int u,w; cnt=0; memset(head,-1,sizeof(head)); scanf("%d%d",&n,&m); for(int i=2;i<=n;i++){ scanf("%d%d",&u,&w); add(i,u,w); add(u,i,w); } dfs1(1,-1); dfs2(1,-1); for(int i=1;i<=n;i++){ dis[i]=max(dp[i][0],dp[i][2]); } build(1,n,1); int ans=0; int l,r; l=r=1; while((l+r)<n*2){ int maxs=query0(l,r,1,n,1); int mins=query1(l,r,1,n,1); /* cout<<l<<" "<<r<<endl; cout<<endl; cout<<maxs<<" "<<mins<<endl;*/ if(mins+m>=maxs){ ans=max(ans,r-l+1); r++; if(r>n)break; } else{ l++; } } cout<<ans; return 0; }