【NOIP2015】运输计划(二分+lca+树上差分)
不得不承认noip的题出的是真的好
n=300000,m=300000的极限数据不由得想到某种nlogn的做法
这道题乍一看和二分没有一点关系,然而我们仔细想想后发现,对于一个时间t1,如果t1之内可以完成,那么t2肯定也能完成!
满足单调性,因此我们可以二分时间,那么如何check呢?
首先,对于一个t,那么所有m必定分为两类——耗时超过t,耗时小于等于t的。
由于我们只有一次变虫洞的机会。如果有多个耗时超过t的链,那么我们变虫洞的地方只能在所有链经过的公共部分,因为只有这样才能尽量使最大耗时小于t。
那么如何去找经过的公共部分呢?暴力的思路是对于每个超过t的链的每条边都去标记一下(找到lca,然后暴力往上跳,复杂度O(log(n)mn)),然后扫一遍。
很明显不行,于是又仔细一想,自然而然想到了树上差分。
所以,对于每个超出的链,我们只需要diff[from]++,diff[to]++,diff[lca]-=2,最后扫一遍统计即可。
#include<bits/stdc++.h>
#define N 300005
#define M 300005
using namespace std;
int n,m,first[N],tot;
int dis[N],up[N][25],depth[N],d[M],pre[N];
int from[M],end[M],l[M],maxlen;
int diff[N],cnt,ret;
struct node
{
int to,next,val;
}edge[2*N];
inline void addedge(int x,int y,int z)
{
tot++;
edge[tot].to=y;
edge[tot].val=z;
edge[tot].next=first[x];
first[x]=tot;
}
inline void dfs1(int now,int fa)
{
depth[now]=depth[fa]+1;
up[now][0]=fa;
for(int i=1;i<=24;i++) up[now][i]=up[up[now][i-1]][i-1];
for(int u=first[now];u;u=edge[u].next)
{
int vis=edge[u].to;
if(vis==fa) continue;
dis[vis]=dis[now]+edge[u].val;
pre[vis]=edge[u].val;
dfs1(vis,now);
}
}
inline int lca(int x,int y)
{
if(depth[x]<depth[y]) swap(x,y);//默认x比y深
for(int i=24;i>=0;i--) if(depth[up[x][i]]>=depth[y]) x=up[x][i];
if(x==y) return x;
for(int i=24;i>=0;i--)
if(up[x][i]!=up[y][i])
{
x=up[x][i];
y=up[y][i];
}
return up[x][0];
}
inline void dfs2(int now)
{
for(int u=first[now];u;u=edge[u].next)
{
int vis=edge[u].to;
if(vis==up[now][0]) continue;
dfs2(vis);
diff[now]+=diff[vis];
}
if(diff[now]==cnt) ret=max(ret,pre[now]); //如果有多条链都超过了要求 那虫洞只能建立在这些链都经过的边上 这里贪心找
}
inline bool check(int t)
{
memset(diff,0,sizeof(diff));
ret=cnt=0;
for(int i=1;i<=m;i++)
{
if(d[i]>t) {//把所有超过要求的链统计一下
diff[from[i]]++;
diff[end[i]]++;
diff[l[i]]-=2;
cnt++;
}
}
dfs2(1);
if(maxlen-ret>t) return false;
return true;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
cin>>n>>m;
for(int i=1;i<=n-1;i++)
{
int x,y,z;
cin>>x>>y>>z;
addedge(x,y,z);
addedge(y,x,z);
}
dfs1(1,0);
for(int i=1;i<=m;i++)
{
cin>>from[i]>>end[i];
l[i]=lca(from[i],end[i]);
d[i]=dis[from[i]]+dis[end[i]]-2*dis[l[i]];
maxlen=max(maxlen,d[i]);
}
int l=0,r=maxlen;
while(l<r)
{
int m=(l+r)>>1;
if(check(m)) r=m;
else l=m+1;
}
cout<<l;
return 0;
}
QQ40523591~欢迎一起学习交流~