UOJ #150 【NOIP2015】 运输计划
题目描述
公元 \(2044\) 年,人类进入了宇宙纪元。
\(L\) 国有 \(n\) 个星球,还有 \(n-1\) 条双向航道,每条航道建立在两个星球之间,这 \(n-1\) 条航道连通了 \(L\) 国的所有星球。
小 \(P\) 掌管一家物流公司, 该公司有很多个运输计划,每个运输计划形如:有一艘物流飞船需要从 \(u_i\) 号星球沿最快的宇航路径飞行到 $v_i$ 号星球去。显然,飞船驶过一条航道是需要时间的,对于航道 $j$,任意飞船驶过它所花费的时间为 $t_j$,并且任意两艘飞船之间不会产生任何干扰。
为了鼓励科技创新, $L$ 国国王同意小 $P$ 的物流公司参与 $L$ 国的航道建设,即允许小$P$ 把某一条航道改造成虫洞,飞船驶过虫洞不消耗时间。
在虫洞的建设完成前小 $P$ 的物流公司就预接了 $m$ 个运输计划。在虫洞建设完成后,这 $m$ 个运输计划会同时开始,所有飞船一起出发。当这 $m$ 个运输计划都完成时,小 $P$ 的物流公司的阶段性工作就完成了。
如果小 $P$ 可以自由选择将哪一条航道改造成虫洞, 试求出小 $P$ 的物流公司完成阶段性工作所需要的最短时间是多少?
输入格式
第一行包括两个正整数 $n,m$,表示 L 国中星球的数量及小 P 公司预接的运输计划的数量,星球从 $1$ 到 $n$ 编号。
接下来 $n-1$ 行描述航道的建设情况,其中第 $i$ 行包含三个整数 $a_i,b_i$ 和 $t_i$,表示第 $i$ 条双向航道修建在 $a_i$ 与 $b_i$ 两个星球之间,任意飞船驶过它所花费的时间为 $t_i$。数据保证 $1 \leq a_i,b_i \leq n$且 $0 \leq t_i \leq 1000$。
接下来 $m$ 行描述运输计划的情况,其中第 $j$ 行包含两个正整数 $u_j$ 和$v_j$,表示第 $j$ 个运输计划是从 $u_j$ 号星球飞往 $v_j$号星球。数据保证 $1 \leq u_i,v_i \leq n$
输出格式
输出文件只包含一个整数,表示小 $P$ 的物流公司完成阶段性工作所需要的最短时间。
还记得当初在$noip$考场上,不会树剖不会二分答案,于是对于这道题就是狂跑lca啊lca……
后来学了各种东西之后,就自己打了一个复杂度为$O(n \log ^2 n)$的算法,大意如下:
先对于整棵树进行树链剖分,然后考虑二分一个答案(因为题目所求是最大值最小,所以答案单调),只需判断这个答案是否可行。于是我们需要把长度大于$x$的路径扫一遍,求一下这些路径的交,从交中找出一条权值最大的边,把这条边权值变为$0$(显然这样最优而且并不需要真的赋值为$0$),判断一下最长路径现在是否小于等于$x$即可。
接着,发现被卡了。在UOJ上只有97分。于是,接下来就是奇技淫巧时间。
有一次有点无聊,于是把树链剖分中的求重儿子部分的代码中的小于号改为了小于等于号,就这么AC了……汗……
下面是代码:
#include<iostream> #include<cstdio> #include<cstring> #define INF 2147483647 #define maxn 300001 using namespace std; struct data{ int f,t,v; }he[maxn]; int fa[maxn],head[maxn*2],to[maxn*2],next[maxn*2],c[maxn*2]; int dep[maxn],son[maxn],top[maxn],siz[maxn],w[maxn],zui; int tc[maxn],fc1[maxn],fc[maxn],cha[maxn+1],tt,m,n,x,y,l,r; int getint(){ int w=0,q=0; char c=getchar(); while((c>'9'||c<'0')&&c!='-') c=getchar(); if(c=='-') q=1,c=getchar(); while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w; } void dfs1(int u,int dd){ dep[u]=dd;siz[u]++; for(int i=head[u];i;i=next[i]) if(!dep[to[i]]){ dfs1(to[i],dd+1);siz[u]+=siz[to[i]]; fa[to[i]]=u;fc[to[i]]=c[i]; if(siz[to[i]]>=siz[son[u]]) son[u]=to[i]; } } void dfs2(int u,int dd){ w[u]=++tt;top[u]=dd; fc1[tt]=fc[u]; if(son[u]) tc[son[u]]=tc[u]+fc[son[u]],dfs2(son[u],dd); for(int i=head[u];i;i=next[i]) if(to[i]!=son[u]&&to[i]!=fa[u]) dfs2(to[i],to[i]); } inline int lca(int x,int y){ int ju=0; while(top[x]!=top[y]){ int a=fa[top[x]],b=fa[top[y]]; if(dep[a]<dep[b]) swap(x,y),swap(a,b); cha[w[x]+1]--;cha[w[top[x]]]++; ju+=tc[x]+fc[top[x]];x=a; } if(dep[x]<dep[y]) swap(x,y); cha[w[x]+1]--;cha[w[son[y]]]++; ju+=tc[x]-tc[y]; return ju; } inline bool pd(int kk){ int ci=0; memset(cha,0,sizeof(cha)); for(register int i=1;i<=m;i++) if(he[i].v>kk) lca(he[i].f,he[i].t),ci++; int hehe=0,hh=c[0]; for(register int i=1;i<=tt;i++){ hh+=cha[i]; if(hh==ci) hehe=max(hehe,fc1[i]); } if(zui-hehe<=kk) return 1; return 0; } int main(){ n=getint();m=getint(); for(int i=1;i<n;i++){ x=getint();y=getint(); to[++tt]=y,next[tt]=head[x],head[x]=tt; to[++tt]=x,next[tt]=head[y],head[y]=tt; c[tt-1]=c[tt]=getint(); } dfs1(1,1);tt=0;dfs2(1,1); for(int i=1;i<=m;i++){ he[i].f=getint();he[i].t=getint(); he[i].v=lca(he[i].f,he[i].t); r=max(r,he[i].v); } zui=r++; while(l!=r){ int mid=(l+r)>>1; if(pd(mid))r=mid; else l=mid+1; } printf("%d",l); return 0; }
但是后来,我发现其实这个算法是可以优化到$O(n \log n)$的。设当前二分的答案为$x$,由于每一次最长路径长度小于等于$x$时显然可以,否则不满足的所有路径的交必定在最长路径上(交为空时其实也是最长路径的子集),而且是连续的一段。然后我们把最长链抠出来,对于其他每一条路径与最长路径求一下交,那么路径的交就变成区间上的问题了,可以线性地做。于是复杂度降为$O(n \log n)$。
还有怎么两条路径求交的问题。由于我们只需求出两条路径交的两个端点,可以把这两条路径的共4个点lca一下,分类讨论即可。
下面贴代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout) #define maxn 300010 using namespace std; typedef long long llg; struct data{ int u,v,g,x; }s[maxn]; int n,m,dep[maxn],son[maxn],siz[maxn],fa[maxn],tc[maxn],top[maxn],fc[maxn]; int head[maxn],to[maxn<<1],next[maxn<<1],c[maxn<<1],tt,K,nn,now; int dc[maxn],cd[maxn],ld,b[maxn],lb,bb[maxn],ca[maxn]; int getint(){ int w=0;bool q=0; char c=getchar(); while((c>'9'||c<'0')&&c!='-') c=getchar(); if(c=='-') c=getchar(),q=1; while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w; } void link(int x,int y){ to[++tt]=y;next[tt]=head[x];head[x]=tt; to[++tt]=x,next[tt]=head[y];head[y]=tt; c[tt]=c[tt-1]=getint(); } void dfs(int u){ siz[u]=1; for(int i=head[u],v;v=to[i],i;i=next[i]) if(!siz[v]){ dep[v]=dep[u]+1; fa[v]=u; fc[v]=c[i]; dfs(v); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } } void dfs(int u,int d){ top[u]=d; if(son[u]) tc[son[u]]=tc[u]+fc[son[u]],dfs(son[u],d); for(int i=head[u],v;v=to[i],i;i=next[i]) if(!top[v]) dfs(v,v); } inline int lca(int u,int v){//树链剖分求lca now=0; while(top[u]!=top[v]){ if(dep[fa[top[u]]]<dep[fa[top[v]]]) swap(u,v); now+=tc[u]+fc[top[u]]; u=fa[top[u]]; } if(dep[u]>dep[v]) swap(u,v); now+=tc[v]-tc[u]; return u; } void kou(int u,int v,int g){//把最长路径抠出来 while(u!=g) cd[++ld]=u,ca[ld]=fc[u],u=fa[u]; cd[++ld]=g; //从u到g的路径 while(v!=g) b[++lb]=v,bb[lb]=fc[v],v=fa[v]; //从v到g的路径 while(lb) ca[ld]=fc[b[lb]],cd[++ld]=b[lb--]; //ca表示长度 for(int i=1;i<=ld;i++) dc[cd[i]]=i; } inline bool pd(int x){//判断是否可行 int l=1,r=ld,gi=0; for(register int i=1;i<=m;i++) if(s[i].x>x){ l=max(l,s[i].u); r=min(r,s[i].v); } for(int i=l;i<r;i++) gi=max(gi,ca[i]); return nn-gi<=x; } inline void gi(int &xx,int u){ int x,y; x=lca(u,s[K].u); if(dep[x]>dep[s[K].g]){xx=dc[x];return;} y=lca(u,s[K].v); if(dep[y]>dep[s[K].g]){xx=dc[y];return;} xx=dc[s[K].g];//点在路径之外,要么另一个点的两个lca在最长路径上,此时视为最长路径的最上点,要么两条路径交为空 } int main(){ n=getint(); m=getint(); for(int i=1;i<n;i++) link(getint(),getint()); dep[1]=1;dfs(1); dfs(1,1); for(register int i=1;i<=m;i++){ s[i].u=getint(); s[i].v=getint(); s[i].g=lca(s[i].u,s[i].v); s[i].x=now; if(now>nn) nn=now,K=i;//nn为最长路径长度,K为最长路径标号 } kou(s[K].u,s[K].v,s[K].g); for(register int i=1,l,r;i<=m;i++) if(i!=K){ gi(l,s[i].u); gi(r,s[i].v);//lca分类讨论求路径交 if(l>r) swap(l,r); s[i].u=l,s[i].v=r;//注意这里的区间为[l,r) } s[K].u=1,s[K].v=ld; int l=0,r=nn,mid; while(l!=r){ mid=l+r>>1; if(pd(mid)) r=mid; else l=mid+1; } printf("%d",l); }
又或者可以直接优化查分,使得差分变为$O(1)$,具体思路如下:把一条路径上的边全部加$1$,那么可以把两个端点的值加$1$,lca的值减$2$,全部完成后dfs一遍解决。但不知道为什么,速度比第一种算法还要慢,大概是常数有点大吧……
下面贴代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout) #define maxn 300010 using namespace std; typedef long long llg; struct data{ int u,v,g,x; }s[maxn]; int n,m,dep[maxn],son[maxn],siz[maxn],fa[maxn],tc[maxn],top[maxn],fc[maxn]; int head[maxn],to[maxn<<1],next[maxn<<1],c[maxn<<1],tt,now,nn; int ch[maxn],ci,_m; int getint(){ int w=0;bool q=0; char c=getchar(); while((c>'9'||c<'0')&&c!='-') c=getchar(); if(c=='-') c=getchar(),q=1; while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w; } void link(int x,int y){ to[++tt]=y;next[tt]=head[x];head[x]=tt; to[++tt]=x,next[tt]=head[y];head[y]=tt; c[tt]=c[tt-1]=getint(); } void dfs(int u){ siz[u]=1; for(int i=head[u],v;v=to[i],i;i=next[i]) if(!siz[v]){ dep[v]=dep[u]+1; fa[v]=u; fc[v]=c[i]; dfs(v); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } } void dfs(int u,int d){ top[u]=d; if(son[u]) tc[son[u]]=tc[u]+fc[son[u]],dfs(son[u],d); for(int i=head[u],v;v=to[i],i;i=next[i]) if(!top[v]) dfs(v,v); } int lca(int u,int v){ now=0; while(top[u]!=top[v]){ if(dep[fa[top[u]]]<dep[fa[top[v]]]) swap(u,v); now+=tc[u]+fc[top[u]]; u=fa[top[u]]; } if(dep[u]>dep[v]) swap(u,v); now+=tc[v]-tc[u]; return u; } void work(int u){ for(int i=head[u],v;v=to[i],i;i=next[i]) if(v!=fa[u]) work(v),ch[u]+=ch[v],ch[v]=0; if(ch[u]==ci) _m=max(_m,fc[u]); } bool pd(int x){ ci=_m=ch[1]=0; for(int i=1;i<=m;i++) if(s[i].x>x){ ch[s[i].u]++;ch[s[i].v]++; ch[s[i].g]-=2; ci++; } work(1); return nn-_m<=x; } int main(){ File("a"); n=getint(); m=getint(); for(int i=1;i<n;i++) link(getint(),getint()); dep[1]=1;dfs(1); dfs(1,1); for(int i=1;i<=m;i++){ s[i].u=getint(); s[i].v=getint(); s[i].g=lca(s[i].u,s[i].v); nn=max(nn,s[i].x=now); } int l=0,r=nn,mid; while(l!=r){ mid=l+r>>1; if(pd(mid)) r=mid; else l=mid+1; } printf("%d",l); }
不过这道题的标解的复杂度好像是并查集复杂度……留坑待填……