luogu5021题解

形式化题意:在一棵树上找 \(m\) 条没有重复边的路径,使得最短的路径最大,求这个最短路径的最大值。

看到这个最大值就想二分答案。

\(1\le m\le n-1\),我们可以将长度的下限为最短的那条边,此时所有边都是符合要求的路径。
长度的上限为所有路径的长度除以 \(m\),向下取整。

我们在判断这个长度是否可以找出 \(m\) 条路径的过程中,用一个贪心的方法来使选取的边最多,然后判断边的数量是否足够即可。

当一棵树为菊花树时,我们令度最大的点为根节点,可以发现路径一定是根节点到子节点的一条边或是两条边拼起来。
一条边的情况需要该边大于等于要求的长度,满足这个条件的边没必要和其他边拼起来,直接作为一条路径即可。
除去大于等于要求长度的边后,剩下的边都不构成路径或是两条边的情况。
为使求得的路径最多,对于每条边我们都尽可能选择可以组成符合要求的路径但长度尽量小的边,因为长度大的边可能可以与更小的边组成符合要求的路径,但长度小的边可能不可以。
我们可以证明在菊花图上以任何顺序选取边并采取贪心策略的正确性。证明太长了拿出来水篇闲话(

如果该树不是菊花树时,会有非根的非叶节点。
此时可以有一条边与返祖的边连成路径。
不过与返祖边连成路径最多会使得路径多一条,在某节点各个向下的边连成路径的方案一定不劣于与返祖边连成路径。所以我们在某节点统计完路径选取剩下的边中最长的一边与上面的边连成路径即可。
此过程需要所在节点与子节点的边中从长度小的边开始选取,使得选取的与返祖边组成路径的边最长。
被选择的这条边在父节点可以将父节点与该子节点的边连起来,看作一条新边。
这样我们遍历一遍就可以解决这个问题了。

在一个节点统计路径与剩余最长路径的过程中和这篇博客类似,先将所有的边(我们把子节点遍历后返回的最长路径与该节点到子节点的边看作这里的边)排序,然后用 \(O(n)\) 的方法扫了一遍。
对于原本就大于等于 \(mid\) 的边直接累加进符合要求路径数中。

剩下的边从小的开始选,将符合条件的边由大到小入栈,这个过程用双指针实现,若栈非空则小的边与栈中的边合起来,弹栈并使符合要求路径数加 \(1\),否则更新向上返回的路径长度。

最后可能还有一些边留着栈中,由于这些边可以比其更小的边拼成一个符合要求的路径,所以其中任意两条边都可以拼成符合要求的路径,我们两个两个的弹栈并使符合要求路径数加 \(1\),最后如果还有边就更新向上返回的路径长度。

时间复杂度受排序所限为 \(O(n\log n)\),不过常数比 std::multiset 要小。

照着题解写的 std::multiset代码确实是慢了一些。
代码如下。

#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
typedef pair<int,int> pii;
constexpr int MAXN=5e4+10;
int n,m,mid;
vector<pii> e[MAXN];
void add(int u,int v,int w){
    e[u].pb({v,w});
}
pii dfs(int u,int fa){
    int retfi=0,retse=0;
    vector<int> te;
    for(pii i:e[u]){
        if(i.fi==fa)continue;
        pii temp=dfs(i.fi,u);
        retfi+=temp.fi;
        int w=i.se+temp.se;
        if(w>=mid)++retfi;
        else te.pb(w);
    }
    sort(te.begin(),te.end());
    int pf=0,pl=te.size()-1;
    stack<int> s;
    while(pf<=pl){
        while(pf<pl&&te[pl]+te[pf]>=mid){
            s.push(te[pl]);--pl;
        }
        if(!s.empty()){
            ++retfi;s.pop();
        }else retse=te[pf];
        ++pf;
    }
    while(s.size()>=2){
        ++retfi;s.pop();s.pop();
    }
    if(!s.empty())retse=s.top();
    return {retfi,retse};
}
bool check(){
    return dfs(1,0).fi>=m;
}
int main(){
    scanf("%d%d",&n,&m);
    int tot=0;
    for(int i=1,a,b,l;i<n;++i){
        scanf("%d%d%d",&a,&b,&l);
        add(a,b,l);add(b,a,l);tot+=l;
    }
    int l=1,r=tot/m;
    while(l<r){
        mid=(l+r+1)>>1;
        if(check()){
            l=mid;
        }else{
            r=mid-1;
        }
    }
    printf("%d\n",l);
    return 0;
}
posted @ 2024-02-20 21:40  LiJoQiao  阅读(12)  评论(2编辑  收藏  举报