运输计划 二分答案 树上差分
运输计划 二分答案 树上差分
附赠样例图:
首先容易想到二分答案,将求解问题转换为判定,并且注意到答案显然具有单调性。
然后我们考虑如何快速判定。先预处理出所有计划路径长(\(dis[i]\)表示根节点到节点\(i\)的距离,\(dis[u]+dis[v]-2\times dis[lca(u,v)]\)即为树上\(u,v\)路径长),对于那些路径长已经超过二分距离\(x\)的计划,我们必须能找到一条公共边 且 删去此边后最长距离小于等于\(x\),否则无法满足\(x\),判定当前\(x\)为假。
之后考虑如何快速求得公共边,使用树上差分,具体将边塞入较深的点,cnt[u]+=1,cnt[v]+=1,cnt[lca(u,v)]-=2
,最后再回溯时求树上前缀和即可求得边覆盖次数。一条边是公共边当且仅当覆盖次数为路径个数。
综上,判定时找到超过\(x\)的路径的最长的公共边后,判断删去后是否能满足\(x\)即可
其实不是很毒瘤,相比于瘟疫控制
#include <cstdio>
#include <algorithm>
#include <cstring>
#define MAXN 300003
#define LOG 20
using namespace std;
inline int read(){
char ch=getchar();int s=0;
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') s=s*10+(ch^'0'),ch=getchar();
return s;
}
int head[MAXN],nxt[MAXN*2],vv[MAXN*2],ww[MAXN*2],tot;
inline void add_edge(int u, int v, int w){
vv[++tot]=v;
ww[tot]=w;
nxt[tot]=head[u];
head[u]=tot;
}
int dis[MAXN];
int dep[MAXN],f[MAXN][LOG];
void dfs(int u, int fa){
dep[u]=dep[fa]+1;
f[u][0]=fa;
for(int i=1;i<LOG;++i)
f[u][i]=f[f[u][i-1]][i-1];
for(int i=head[u];i;i=nxt[i]){
int v=vv[i],w=ww[i];
if(v==fa) continue;
dis[v]=dis[u]+w;
dfs(v, u);
}
}
int lca(int a, int b){
if(dep[a]<dep[b]) swap(a,b);
for(int i=LOG-1;i>=0;--i)
if(dep[f[a][i]]>=dep[b])
a=f[a][i];
if(a==b) return a;
for(int i=LOG-1;i>=0;--i)
if(f[a][i]!=f[b][i])
a=f[a][i],b=f[b][i];
return f[a][0];
}
int n,m;
struct nod{
int u,v,lca;
}a[MAXN];
int d[MAXN],dmx;
int cnt[MAXN];
void calc(int u, int fa, const int num, int &res){
for(int i=head[u];i;i=nxt[i]){
int v=vv[i],w=ww[i];
if(v==fa) continue;
calc(v, u, num, res);
if(cnt[v]==num) res=max(res, w);
cnt[u]+=cnt[v];
}
}
bool check(int x){
memset(cnt, 0, sizeof cnt);
int num=0,res=0;
for(int i=1;i<=m;++i){
if(d[i]<=x) continue;
cnt[a[i].u]+=1;
cnt[a[i].v]+=1;
cnt[a[i].lca]-=2;
++num;
}
if(num==0) return x>=dmx;
calc(1, 0, num, res);
if(dmx-res<=x) return 1;
return 0;
}
int main(){
n=read(),m=read();
for(int i=1;i<n;++i){
int u=read(),v=read(),w=read();
add_edge(u, v, w);
add_edge(v, u, w);
}
dfs(1, 0);
for(int i=1;i<=m;++i) a[i].u=read(),a[i].v=read();
for(int i=1;i<=m;++i){
a[i].lca=lca(a[i].u, a[i].v);
d[i]=dis[a[i].u]+dis[a[i].v]-2*dis[a[i].lca];
dmx=max(dmx, d[i]);
}
int l=0,r=dmx;
int ans=-1;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)) ans=mid, r=mid-1;
else l=mid+1;
}
printf("%d", ans);
return 0;
}