Kuhn-Munkres算法

\(Km\)

Kuhn-Munkres算法

一种用于进行二分图完全匹配的算法


\(pre\)技能

匈牙利算法及增广路

标顶

对于图\(G(U\cup V,E)\)。对于\(x\in U\),定义\(Lx_i\)。对于\(i\in V\)。定义\(Ly_i\)

这个玩意叫做标顶,是一种人为构造的数值。用于进行二分图完全匹配

可行标顶

对于所有的边,假设权值是\(W\),方向是\(x\to y\),则是算法执行中,恒定满足\(Lx_x+Ly_y\geq W\)

相等子图

相等子图包括\(U\cup V\),但只包括满足\(W=Lx_x+Ly_y\)的边

若相等子图有完全匹配,则原图有完全匹配

实现概括

通过更改标顶,找出相等子图。


实现具体

扩大相等子图

假设当前相等子图无法进行完全匹配,则至少有一个点\(u\),从其出发无法找到增广路

则我们需要修改标顶,使更多的边参与进来。

假设我们现在已经知道了一个\(M\),这个\(M\)是使其他边增加到相等子图的最小标顶修改量

然后我们令所有的\(Lx_i\)减去\(M\),所有的\(Ly_i\)加上\(M\)

正确性?

假设我们现在在相等子图中在\(U\)中已经被匹配的点\(x\),则我们规定\(x\in A\),否则\(x\in A'\)

相似的,对于\(x\in V\),若x已经被匹配上,则\(x\in B\),否则\(x\in B'\)

对于一条边(\(x\to y\),权值(代价)是\(W\))

  • \(x\in A,y\in B\),仍满足\(Lx_x+Ly_y=W\)
  • \(x\in A',y\in B'​\),则为\(Lx_x+Ly_y\geq W​\),即是原本就不在交替路中的的依旧不在
  • \(x\in A',y\in B\),则\(Lx_x+Ly_y\)增加,不会被添加进入
  • \(x\in A,y\in B'\),则\(Lx_x+Ly_y\)有所减少,可能会被添加
其他

关于

\(x\in A',y\in B\),则\(Lx_x+Ly_y\)增加,不会被添加进入

\(x\in A,y\in B'\),则\(Lx_x+Ly_y\)有所减少,可能会被添加

对于上边这句,是为了保证在将一个点\(x,x\in U\)进行匹配时,只会相应的引入\(y,y\in V\),而不会引入\(U\)中的点。

\(M\)如何计算?

\(M\)的大小取决于没有被加到相等子图中的最大边的大小,即是\(Lx_x+Ly_y-W\)

而这个我们在寻找增广路的时候就可以顺带更新。

其他

\(M\)的贪心选取保证了最大权。

个人理解是损失最小的代价,使其可以加入到相等子图

代码

using std::max;
using std::min;
const int maxn=500;
const int inf=0x7fffffff;
int M[maxn][maxn];
int m[maxn][maxn];
int lx[maxn],ly[maxn],mins[maxn],pat[maxn];
int S[maxn],T[maxn],tot;
int vis[maxn];
int n;
bool find(int x)
{
    S[x]=1;
    for(int i=1;i<=n;i++)
    {
        if(T[i])    continue;
        int s=lx[x]+ly[i]-m[x][i];
        if(!s)
        {
            T[i]=1;
            if(!pat[i]||find(pat[i]))
            {
                pat[i]=x;
                return true;
            }
        }
        else
            mins[i]=min(mins[i],s);
    }
    return false;
}
int KM()
{
    for(int i=1;i<=n;i++)   lx[i]=-inf;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            lx[i]=max(lx[i],m[i][j]),ly[i]=0;
    memset(pat,0,sizeof(pat));
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
            mins[j]=inf;
        while(1)
        {
            memset(T,0,sizeof(T));memset(S,0,sizeof(S));
            if(find(i)) break;
            int Min=inf;
            for(int j=1;j<=n;j++)   Min=min(Min,mins[j]);
            for(int j=1;j<=n;j++)
            {
                if(S[j])    lx[j]-=Min;
                if(T[j])    ly[j]+=Min;
            }
        }
    }
    int ans=0;
    for(int i=1;i<=n;i++)
        ans+=m[pat[i]][i];
    return ans;
}
posted @ 2018-12-26 17:30  Lance1ot  阅读(951)  评论(0编辑  收藏  举报