洛谷2619/bzoj2654 Tree(凸优化+MST)
bzoj的数据是真的水。。
qwq
由于本人还有很多东西不是很理解
qwq
所以这里只写一个正确的做法。
首先,我们会发现,对于你选择白色边的数目,随着数目的上涨,斜率是单调升高的。
那么这时候我们就可以考虑凸优化,也就是\(wqs\)二分来满足题目中所述的正好\(k\)条边的限制。
我们\(erf\)一个\(mid\),然后让每一个白边的权值都加上\(mid\),然后跑\(MST\),看最后的选的白色边数,是否是大于等于\(k\)的,如果是,就调大\(l\),否则调小\(r\)。
由于最小生成树选择边的时候可能有一些玄学的错误,所以我们在\(sort\)的时候,对于权值相等的边,我们优先选择白边。
那么通过\(erf\),之后,我们就能得到一个上界,也就是在当前的偏移量下,我们最多的选和1相连的边的个数。
根据\(clj\)的官方题解,这里有两个引理
对于一个图,如果存在一个最小生成树,它的白边的数量是\(x\),那么就称\(x\)是最小合法白边数。所有的最小合法白边数形成一个区间\([l,r]\)
(因为题目保证有解,所以我们只需要找到最小的\(r\)即可)
那么经过这个\(erf\),我们就能得到一个最小的\(r\)
那么我们应该怎么求整个\(MST\)的权值呢,我们会发现,对于权值相等的白边和黑边,由于题目保证有解,所以一定是会存在相互替代的关系的。
那我们可以按照之前的最小生成树的策略选白边,将其记为\(val\),最后输出\(val-k*ans\),\(ans\)表示最后的\(mid\)。
为什么是\(k\)而不是具体的选的边的数目呢?
因为题目要求正好选择\(k\)条,而我们这里实际上是把多余的白边都直接视为黑边来做了
qwqwq
那么这个题就能解决了
qwqwqwqwq
但是我根据CF125E那个题,有一个比较特殊的做法,但是套到这个这个题,我并不是很理解。qwq
这个坑还是之后再填吧
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<vector>
#include<map>
#include<vector>
#define mk make_pair
#define pb push_back
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 4e5+1e2;
struct Edge{
int u,v,w;
int col;
};
Edge e[maxn];
int n,m;
int ans;
int l=-200,r=200;
int fa[maxn];
int find(int x)
{
if (fa[x]!=x) fa[x]=find(fa[x]);
return fa[x];
}
int k;
bool cmp(Edge a,Edge b)
{
if (a.w==b.w) return a.col<b.col;
return a.w<b.w;
}
int solve()
{
sort(e+1,e+1+m,cmp);
int tot=0;
for (int i=1;i<=m;i++)
{
int f1 = find(e[i].u);
int f2 = find(e[i].v);
if (f1==f2) continue;
//if(tot==k && e[i].col==0) continue;
if (e[i].col==0) ++tot;
fa[f1]=fa[f2];
}
return tot;
}
signed main()
{
n=read(),m=read();k=read();
for (int i=1;i<=m;i++)
{
e[i].u=read()+1;
e[i].v=read()+1;
e[i].w=read();
e[i].col=read();
}
while(l<=r)
{
int mid = (l+r) >> 1;
for (int i=1;i<=n;i++) fa[i]=i;
for (int i=1;i<=m;i++)
{
if (e[i].col==0) e[i].w+=mid;
}
int tmp = solve();
if (tmp<k)
{
r=mid-1;
}
else l=mid+1,ans=mid;
for (int i=1;i<=m;i++)
{
if (e[i].col==0) e[i].w-=mid;
}
}
for (int i=1;i<=n;i++) fa[i]=i;
for (int i=1;i<=m;i++)
if (e[i].col==0) e[i].w+=ans;
sort(e+1,e+1+m,cmp);
int tot=0,val=0;
for (int i=1;i<=m;i++)
{
int f1 = find(e[i].u);
int f2 = find(e[i].v);
if (f1==f2) continue;
if (e[i].col==0) ++tot;
fa[f1]=fa[f2];
val+=e[i].w;
}
cout<<val-k*ans;
return 0;
}