二分图最大权完美匹配(KM算法)
KM算法
KM算法用来解决二分图最大权完美匹配问题,是对匈牙利算法算法的改进
具体思路是,假设左边的节点有一个顶标\(ex_i\),右边的节点有一个顶标\(ey_i\)
我们要求顶标对于任意一条边,满足\(ex_i+ey_j \ge w_{i,j}\),当\(ex_i+ey_j=w_{i,j}\),可以连接边\(edge(i,j)\)。所以取到二分图完美匹配时,边权之和为\(\sum_{i=1}^n ex_i+\sum_{i=1}^n ey_i\)
我们需要在满足完美匹配的同时,尽量让顶标和最大
首先我们先寻找顶标和上限(这时候不要求完美匹配),很容易找到可行解
显然不可能有比这个更优的解
但是不一定存在这样的解,我们先不处理,像原来一样一个一个点寻找增广路,只不过经过一条边\(edge(i,j)\)的条件多了一个限制\(ex_i+ey_j=w_{i,j}\)
然后我们发现没有完全匹配怎么办?
定义:\(matched_x\)为\(x\)集中已匹配的点,\(matched_y\)为\(y\)集中已匹配的点,\(not\_matched_y\)为\(y\)集中未匹配的点
我们需要降低顶标和,但是为了求最大权完美匹配,我们必须让顶标减小的值尽可能小,也就是说,我们要找到一个\(not\_matched_y\)集中的点\(j\),满足\(\min ex_i+ey_j-w_{i,j}(i \in matched_x)\)值最小,然后让这条边也加入可能的边集中,这样可以保证不会漏掉更优答案
显然,我们降低顶标后,原来\(matched_y\)集中的点仍然可以匹配,同时有更多的点可以被匹配
设顶标减小值为\(del\),我们要扩大可匹配边集
首先,\(\forall i \in match_x,ex_i-=del\),表明\(match_x\)降低了标准,可以去匹配差为\(ex_i+ey_j-w_{i,j}=del\)的较次点了,同时由于\(\forall i \in match_y,ey_i+=del\),\(match_x\)集中点与\(match_y\)集中点的和不变,仍然可以匹配
如何求\(del\)呢,暴力求太慢,记\(slack_i(i \in not\_matched_y)\)为与\(i\)连边的\(del\)的最小值,可以在匹配过程中同时处理,具体见代码
\(dfs\)版\(KM\)
时间复杂度:\(O(n^2 m)\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 505
#define M 250005
#define INF 9990365505
#define ll long long
using namespace std;
int n,m,x,y,z,t,tim=0,tot,visx[N],visy[N],fr[N],mch[N],nxt[M],d1[M],d2[M];
ll ex[N],ey[N],slack[N];
void add(int x,int y,int z)
{
tot++;
d1[tot]=y;
d2[tot]=z;
nxt[tot]=fr[x];
fr[x]=tot;
}
bool dfs(int u)
{
if (visx[u]==tim)
return false;
visx[u]=tim;
for (int i=fr[u];i;i=nxt[i])
{
int v=d1[i];
ll del=ex[u]+ey[v]-d2[i];
if (!del)
{
visy[v]=tim;
if (!mch[v] || dfs(mch[v]))
{
mch[v]=u;
return true;
}
} else
slack[v]=min(slack[v],del);
}
return false;
}
void KM()
{
for (t=1;t<=n;t++)
{
for (int i=1;i<=n;i++)
slack[i]=INF;
for (;;)
{
tim++;
if (dfs(t))
break;
ll del=INF;
for (int i=1;i<=n;i++)
if (visy[i]!=tim)
del=min(del,slack[i]);
for (int i=1;i<=n;i++)
{
if (visx[i]==tim)
ex[i]-=del;
if (visy[i]==tim)
ey[i]+=del; else
slack[i]-=del;
}
}
}
ll ans=0;
for (int i=1;i<=n;i++)
ans+=ex[i]+ey[i];
printf("%lld\n",ans);
for (int i=1;i<=n;i++)
printf("%d ",mch[i]);
putchar('\n');
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
ex[i]=-INF,ey[i]=0;
for (int i=1;i<=m;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
ex[x]=max(ex[x],(ll)z);
}
KM();
return 0;
}
结果\(T\)了。。。
我们可以发现,在\(dfs\)中,每降低一次顶标,我们又开始从头开始搜索增广路了,有大量的重复信息,其实我们只需要遍历那些新加入的\(slack_i=0(i \in not\_matched_y)\)的节点就好了,因为它们才是可能找到增广路的节点
于是我们有了\(bfs\)的做法
在\(bfs\)过程中,如果\(slack\)要修改,那么我们直接修改它的前驱就好了,因为前驱不会对接下来的路径产生干扰
每个节点最多入队一次,时间复杂度\(O(m)\)
\(bfs\)版\(KM\)
时间复杂度:\(O(nm)\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 505
#define M 250005
#define INF 9990365505
#define ll long long
using namespace std;
int n,m,x,y,z,tot,tim,l,r,q[N],fr[N],nxt[M],d1[M],d2[M];
int pre[N],visx[N],visy[N],mchx[N],mchy[N];
ll ex[N],ey[N],slack[N];
void add(int x,int y,int z)
{
tot++;
d1[tot]=y;
d2[tot]=z;
nxt[tot]=fr[x];
fr[x]=tot;
}
void modify(int cur)
{
for (int x=cur,lst;x;x=lst)
lst=mchx[pre[x]],mchx[pre[x]]=x,mchy[x]=pre[x];
}
void bfs(int cur)
{
for (int i=1;i<=n;i++)
slack[i]=INF,pre[i]=0;
l=1,r=0;
q[++r]=cur;
++tim;
for (;;)
{
while (l<=r)
{
int u=q[l];
l++;
visx[u]=tim;
for (int i=fr[u];i;i=nxt[i])
{
int v=d1[i],cost=d2[i];
if (visy[v]==tim)
continue;
ll del=ex[u]+ey[v]-cost;
if (del<slack[v])
{
slack[v]=del;
pre[v]=u;
if (!del)
{
visy[v]=tim;
if (!mchy[v])
{
modify(v);
return;
}
q[++r]=mchy[v];
}
}
}
}
ll del=INF;
for (int i=1;i<=n;i++)
if (visy[i]!=tim)
del=min(del,slack[i]);
l=1,r=0;
for (int i=1;i<=n;i++)
{
if (visx[i]==tim)
ex[i]-=del;
if (visy[i]==tim)
ey[i]+=del; else
slack[i]-=del;
}
for (int i=1;i<=n;i++)
if (visy[i]!=tim && !slack[i])
{
visy[i]=tim;
if (!mchy[i])
{
modify(i);
return;
}
q[++r]=mchy[i];
}
}
}
void KM()
{
for (int i=1;i<=n;i++)
bfs(i);
ll ans=0;
for (int i=1;i<=n;i++)
ans+=ex[i]+ey[i];
printf("%lld\n",ans);
for (int i=1;i<=n;i++)
printf("%d ",mchy[i]);
putchar('\n');
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
ex[i]=-INF,ey[i]=0;
for (int i=1;i<=m;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
ex[x]=max(ex[x],(ll)z);
}
KM();
return 0;
}