P2680 运输计划

P2680 运输计划

高级树上差分题

经过观察得知,我们要做的是:

清零一条边,使得给定的K条路径的最大值最小

看到最大值最小,想到二分答案

接着,我们对于所有>mid的边,找到它们的公共边

公共边找法:树上差分(当然高级数据结构也可以)

然后判断对于每一条公共边,清零后可不可以成功即可

时间复杂度:O(nlogn)

代码:

#include<bits/stdc++.h>
using namespace std;
const int N=300005;
struct node{
    int st,ed,len,lca;
}s[N];
int n,m;
int hed[N<<1],tal[N<<1],val[N<<1],nxt[N<<1],cnt=0;
int dep[N]={0};
int f[N][30];
int dp[N];
int MAXK=0;
int Nnt=0;
int num[N];
int dis[N]={0};
void addege(int x,int y,int z){
    cnt++;
    tal[cnt]=y;
    val[cnt]=z;
    nxt[cnt]=hed[x];
    hed[x]=cnt;
}
void dfs(int u,int father){
    Nnt++;
    num[Nnt]=u;
    f[u][0]=father;
    for(int i=hed[u];i;i=nxt[i]){
        int v=tal[i];
        if(v==father) continue;
        dep[v]=dep[u]+1;
        dis[v]=dis[u]+val[i];
        dfs(v,u);
    }
}
void Dfs(int u,int father){
    for(int i=hed[u];i;i=nxt[i]){
        int v=tal[i];
        if(v==father) continue;
        Dfs(v,u);
        dp[u]+=dp[v];
    }
}
bool cmp(node x,node y){
    return x.len>y.len;
}
int LCA(int u,int v){
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=28;i>=0;i--) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
    if(u==v) return u;
    for(int i=28;i>=0;i--) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
    return f[u][0];
}
bool check(int mid){
    memset(dp,0,sizeof(dp));
    int uo=0;
    int pp=0;
    for(int i=1;i<=m;i++){
        if(s[i].len>mid){
            dp[s[i].st]++;
            dp[s[i].ed]++;
            dp[s[i].lca]-=2;
            uo++;
            pp=max(pp,s[i].len-mid);
        }
    }
    if(uo==0) return 1;
    for(int i=n;i>=1;i--) dp[f[num[i]][0]]+=dp[num[i]];
    for(int i=2;i<=n;i++) if(dp[i]==uo&&dis[i]-dis[f[i][0]]>=pp) return 1;
    return 0;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++){
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        addege(x,y,z);
        addege(y,x,z);
        MAXK+=z;
    }
    dis[1]=1;
    dep[1]=1;
    dfs(1,1);
    for(int j=1;j<=28;j++){
        for(int i=1;i<=n;i++){
            f[i][j]=f[f[i][j-1]][j-1];
        }
    }
    for(int i=1;i<=m;i++){
        scanf("%d%d",&s[i].st,&s[i].ed);
        s[i].lca=LCA(s[i].st,s[i].ed);
        s[i].len=dis[s[i].st]+dis[s[i].ed]-2*dis[s[i].lca];
    }
    int l=0,r=MAXK,ans=-1;
    while(l<=r){
        int mid=(l+r)>>1;
        if(check(mid)) ans=mid,r=mid-1;
        else l=mid+1; 
    }
    printf("%d\n",ans);
    return 0;
}

 

posted @ 2019-11-10 16:24  QYJ060604  阅读(94)  评论(0编辑  收藏  举报