bzoj4326: NOIP2015 运输计划(二分+LCA+树上差分)
题目链接:https://www.lydsy.com/JudgeOnline/problem.php?id=4326
题目大意:有一颗含有n个顶点的树,每两个点之间有一个边权,现在有m个运输计划,每个运输计划包含u和v,一个运输计划的代价为u到v的最短距离,现在可以使一条边的权值为0,求出使所以计划中代价最大值的最小值。
解题思路:最大值的最小值很明显需要二分答案,首先求出所以运输计划消耗的代价,然后二分答案mid,对于消耗的的代价大于mid的运输计划,假设有m条,我们只能使一条边的权值变为0,所以这条边必定是这m个运输计划的路径交,转化成一个求路径交的问题了,这里我们可以用树上差分来做,再开两个数组:tmp和prev。tmp用来记录点的出现次数(具体点说实际上记录的是点到其父亲的边的出现次数),prev记录每个点到其父亲的那条边权值。对于一条起点s,终点t的路径。我们这样处理:tmp[s]++,tmp[t]++,tmp[LCA(s,t)]-=2。最后要从所有叶结点把权值向上累加用一遍dfs就可以了。tmp[s]++,推到根后,根到s的tmp值均加了1,同理tmp[t]++,从根到t的tmp值也都加了1,而tmp[LCA(s,t)]-=2上推使得从根到LCA(s,t)的tmp值又变回原来的值,就能把多余的路径都减掉。对于多次操作,我只需要维护tmp的值,最后一次性上推即可。如果tmp[i]等于路径的个数,即代表i到其父亲的那条边是所有路径的交。
代码:
#include<iostream> #include<cstdio> #include<cstring> using namespace std; typedef long long ll; const int maxn=300007; int n,m,dist[maxn],tot,depth[maxn],head[maxn],fa[maxn][25],tmp[maxn],prev[maxn]; struct Edge{ int v,next,w; }edge[maxn*2]; void add(int u,int v,int w){ edge[tot].v=v; edge[tot].w=w; edge[tot].next=head[u]; head[u]=tot++; } struct Query{ int u,v,dis,lca; }q[maxn]; void dfs(int u,int pre){ depth[u]=depth[pre]+1; fa[u][0]=pre; for(int i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v==pre){ prev[u]=edge[i].w; continue; } dist[v]=dist[u]+edge[i].w; if(v!=pre) dfs(v,u); } } int LCA(int u,int v){ if(depth[u]<depth[v]) swap(u,v); for(int i=20;i>=0;i--){ if(depth[u]-(1<<i)>=depth[v]) u=fa[u][i]; } if(u==v) return u; for(int i=20;i>=0;i--){ if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i]; } return fa[u][0]; } void dfs1(int u,int pre){ for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v==pre) continue; dfs1(v,u); tmp[u]+=tmp[v]; } } bool check(int x){ memset(tmp,0,sizeof(tmp)); int cnt=0; int Maxx=0; for(int i=1;i<=m;i++){ if(q[i].dis>x){ tmp[q[i].u]++; tmp[q[i].v]++; tmp[q[i].lca]-=2; Maxx=max(Maxx,q[i].dis); cnt++; } } if(cnt==0)return true; dfs1(1,0); for(int i=2;i<=n;i++){ if(tmp[i]==cnt&&Maxx-prev[i]<=x) return true; } return false; } int main(){ memset(head,-1,sizeof(head)); scanf("%d%d",&n,&m); for(int i=1;i<n;i++){ int u,v,w; scanf("%d%d%d",&u,&v,&w); add(u,v,w); add(v,u,w); } depth[0]=-1; dfs(1,0); for(int i=1;i<=m;i++){ scanf("%d%d",&q[i].u,&q[i].v); q[i].lca=LCA(q[i].u,q[i].v); q[i].dis=dist[q[i].u]+dist[q[i].v]-2*dist[q[i].lca]; // cout<<q[i].dis<<endl; // cout<<dist[q[i].u]<<" "<<dist[q[i].v]<<" "<<dist[lca(q[i].u,q[i].v)]<<endl; } int l=0,r=1e9; int ans; while(l<=r){ int mid=(l+r)/2; if(check(mid)){ ans=mid; r=mid-1; }else l=mid+1; } printf("%d\n",ans); return 0; }