NOIp 2015 运输计划
运输计划
问题大意
有n个星球与n-1条双向边,每条边有时间ti,有m个从vi到ui的运输计划。允许你将一条边的时间降为0。同时开始所有的计划,问最小要多少时间完成计划。
输入输出格式
输入格式:
输入文件名为 transport.in。
第一行包括两个正整数 n、m,表示 L 国中星球的数量及小 P 公司预接的运输计划的数量,星球从 1 到 n 编号。
接下来 n-1 行描述航道的建设情况,其中第 i 行包含三个整数 ai, bi 和 ti,表示第
i 条双向航道修建在 ai 与 bi 两个星球之间,任意飞船驶过它所花费的时间为 ti。
接下来 m 行描述运输计划的情况,其中第 j 行包含两个正整数 uj 和 vj,表示第 j个 运输计划是从 uj 号星球飞往 vj 号星球。
输出格式:
输出 共1行,包含1个整数,表示小P的物流公司完成阶段性工作所需要的最短时间。
样例输入输出
input
6 3
1 2 3
1 6 4
3 1 7
4 3 6
3 5 5
3 6
2 5
4 5
output
11
数据规模
n<=300000 m<=300000 1<=a,b,u,v<=n 0<=ti<=1000
解题报告
这一题一眼看出来就是二分答案。关键是怎么检查。这里我们可以用一个小小的贪心。如果枚举的t大与某些计划原本需要的时间,我们可以记下这些计划,如果要去掉一个边的时间,我们肯定要去掉这些超时计划的所有路径的公共边,只有是公共边,才有同时让这些计划不超时的可能。如果去掉一条公共边,能使得耗时最长的计划不超时,就满足要求。
那么,如何找公共边呢?我们可以用树的差分来实现。简单说一下,首先开一个数组tmp,如果有一条路径从vi到ui,那么我们就tmp[vi]++,tmp[ui++],tmp[lca(vi,ui)]-=2;之后,就用一遍dfs,将tmp[]数组进行dfs序的累加。如果有一个tmp[i]的值等于超时点的数量,那么从i到i的父亲的那一条边就算公共边(证明就不说啦,不难)。
#include <cstdio> #include <iostream> #include <cmath> #include <queue> #include <algorithm> #include <cstring> #include <climits> #define MAXN 300000+10 #define MAXM 300000+10 using namespace std; int n,m,head[MAXN],deep[MAXN],dis[MAXN][50],fa[MAXN][50],num; int p[MAXM],pa[MAXN],tmp[MAXN],root=1,s[MAXN],t[MAXN],all; bool vis[MAXN]; struct Edge{ int dis,from,to,next; }edge[MAXN<<1]; void add(int from,int to,int dis) { edge[++num].next=head[from]; edge[num].from=from; edge[num].to=to; edge[num].dis=dis; head[from]=num; } void dfs(int x) { for(int i=head[x];i;i=edge[i].next) if(!deep[edge[i].to]) { deep[edge[i].to]=deep[x]+1; fa[edge[i].to][0]=x; dis[edge[i].to][0]=edge[i].dis; dfs(edge[i].to); } } void dfs2(int x) { int ans=0; for(int i=head[x];i;i=edge[i].next) if(!vis[edge[i].to]) { vis[edge[i].to]=1; dfs2(edge[i].to); ans+=tmp[edge[i].to]; } tmp[x]+=ans; } void init() { deep[root]=1; dfs(root); for(int j=0;(1<<j)<=n;j++) for(int i=1;i<=n;i++) if(fa[i][j-1]) fa[i][j]=fa[fa[i][j-1]][j-1], dis[i][j]=dis[i][j-1]+dis[fa[i][j-1]][j-1]; } int lca(int a,int b,int no) { if(deep[a]>deep[b]) swap(a,b); int d=deep[b]-deep[a],ans=0; for(int j=0;(1<<j)<=d;j++) if((1<<j)&d) { ans+=dis[b][j]; b=fa[b][j]; } if(a==b) {pa[no]=a;return ans;} for(int j=log2(n);j>=0;j--) if(fa[a][j]!=fa[b][j]) { ans+=dis[a][j]+dis[b][j]; a=fa[a][j];b=fa[b][j]; } pa[no]=fa[a][0]; return ans+dis[a][0]+dis[b][0]; } bool check(int ti) { memset(tmp,0,sizeof tmp); int num=0; memset(vis,0,sizeof vis); int maxx=0,maxp=0,maxm=0; for(int i=1;i<=m;i++) { if(p[i]>ti) { num++; tmp[s[i]]++;tmp[t[i]]++; tmp[pa[i]]-=2; maxm=max(maxm,p[i]-ti); } } vis[1]=1; dfs2(1); for(int i=1;i<=n;i++) if(tmp[i]==num&&dis[i][0]>=maxm) return 1; return 0; } void er() { int l=0,r=all,m; while(l<r-1) { m=(l+r)>>1; if(check(m)) r=m; else l=m; } if(check(l)) printf("%d\n",l); else printf("%d\n",r); } int read(){ int in=0;char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar()); for(;ch>='0'&&ch<='9';ch=getchar())in=in*10+ch-'0'; return in; } int main() { n=read();m=read(); int x,y,z; for(int i=1;i<n;i++) { x=read();y=read();z=read(); all+=z; add(x,y,z); add(y,x,z); } init(); for(int i=1;i<=m;i++) { //scanf("%d%d",&s[i],&t[i]); s[i]=read();t[i]=read(); p[i]=lca(s[i],t[i],i); } er(); return 0; }