二分图最大权完美匹配(KM算法)

KM算法

KM算法用来解决二分图最大权完美匹配问题,是对匈牙利算法算法的改进

具体思路是,假设左边的节点有一个顶标\(ex_i\),右边的节点有一个顶标\(ey_i\)

我们要求顶标对于任意一条边,满足\(ex_i+ey_j \ge w_{i,j}\),当\(ex_i+ey_j=w_{i,j}\),可以连接边\(edge(i,j)\)。所以取到二分图完美匹配时,边权之和为\(\sum_{i=1}^n ex_i+\sum_{i=1}^n ey_i\)

我们需要在满足完美匹配的同时,尽量让顶标和最大

首先我们先寻找顶标和上限(这时候不要求完美匹配),很容易找到可行解

\[\forall i \in x,ex_i=\max_{edge(i,j)} w_{i,j}\\ \forall i \in y,ey_i=0 \]

显然不可能有比这个更优的解

但是不一定存在这样的解,我们先不处理,像原来一样一个一个点寻找增广路,只不过经过一条边\(edge(i,j)\)的条件多了一个限制\(ex_i+ey_j=w_{i,j}\)

然后我们发现没有完全匹配怎么办?

定义:\(matched_x\)\(x\)集中已匹配的点,\(matched_y\)\(y\)集中已匹配的点,\(not\_matched_y\)\(y\)集中未匹配的点

我们需要降低顶标和,但是为了求最大权完美匹配,我们必须让顶标减小的值尽可能小,也就是说,我们要找到一个\(not\_matched_y\)集中的点\(j\),满足\(\min ex_i+ey_j-w_{i,j}(i \in matched_x)\)值最小,然后让这条边也加入可能的边集中,这样可以保证不会漏掉更优答案

显然,我们降低顶标后,原来\(matched_y\)集中的点仍然可以匹配,同时有更多的点可以被匹配

设顶标减小值为\(del\),我们要扩大可匹配边集

\[\forall i \in match_x,ex_i-=del\\ \forall i \in match_y,ey_i+=del\\ \]

首先,\(\forall i \in match_x,ex_i-=del\),表明\(match_x\)降低了标准,可以去匹配差为\(ex_i+ey_j-w_{i,j}=del\)的较次点了,同时由于\(\forall i \in match_y,ey_i+=del\)\(match_x\)集中点与\(match_y\)集中点的和不变,仍然可以匹配

如何求\(del\)呢,暴力求太慢,记\(slack_i(i \in not\_matched_y)\)为与\(i\)连边的\(del\)的最小值,可以在匹配过程中同时处理,具体见代码

\(dfs\)\(KM\)

时间复杂度:\(O(n^2 m)\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 505
#define M 250005
#define INF 9990365505
#define ll long long
using namespace std;
int n,m,x,y,z,t,tim=0,tot,visx[N],visy[N],fr[N],mch[N],nxt[M],d1[M],d2[M];
ll ex[N],ey[N],slack[N];
void add(int x,int y,int z)
{
    tot++;
    d1[tot]=y;
    d2[tot]=z;
    nxt[tot]=fr[x];
    fr[x]=tot;
}
bool dfs(int u)
{
    if (visx[u]==tim)
        return false;
    visx[u]=tim;
    for (int i=fr[u];i;i=nxt[i])
    {
        int v=d1[i];
        ll del=ex[u]+ey[v]-d2[i];
        if (!del)
        {
            visy[v]=tim;
            if (!mch[v] || dfs(mch[v]))
            {
                mch[v]=u;
                return true;
            }
        } else
            slack[v]=min(slack[v],del);
    }
    return false;
}
void KM()
{
    for (t=1;t<=n;t++)
    {
        for (int i=1;i<=n;i++)
            slack[i]=INF;
        for (;;)
        {
            tim++;
            if (dfs(t))
                break;
            ll del=INF;
            for (int i=1;i<=n;i++)
                if (visy[i]!=tim)
                    del=min(del,slack[i]);
            for (int i=1;i<=n;i++)
            {
                if (visx[i]==tim)
                    ex[i]-=del;
                if (visy[i]==tim)
                    ey[i]+=del; else
                    slack[i]-=del;
            }
        }
    }
    ll ans=0;
    for (int i=1;i<=n;i++)
        ans+=ex[i]+ey[i];
    printf("%lld\n",ans);
    for (int i=1;i<=n;i++)
        printf("%d ",mch[i]);
    putchar('\n');
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
        ex[i]=-INF,ey[i]=0;
    for (int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        ex[x]=max(ex[x],(ll)z);
    }
    KM();
    return 0;
}

结果\(T\)了。。。

我们可以发现,在\(dfs\)中,每降低一次顶标,我们又开始从头开始搜索增广路了,有大量的重复信息,其实我们只需要遍历那些新加入的\(slack_i=0(i \in not\_matched_y)\)的节点就好了,因为它们才是可能找到增广路的节点

于是我们有了\(bfs\)的做法

\(bfs\)过程中,如果\(slack\)要修改,那么我们直接修改它的前驱就好了,因为前驱不会对接下来的路径产生干扰

每个节点最多入队一次,时间复杂度\(O(m)\)

\(bfs\)\(KM\)

时间复杂度:\(O(nm)\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 505
#define M 250005
#define INF 9990365505
#define ll long long
using namespace std;
int n,m,x,y,z,tot,tim,l,r,q[N],fr[N],nxt[M],d1[M],d2[M];
int pre[N],visx[N],visy[N],mchx[N],mchy[N];
ll ex[N],ey[N],slack[N];
void add(int x,int y,int z)
{
    tot++;
    d1[tot]=y;
    d2[tot]=z;
    nxt[tot]=fr[x];
    fr[x]=tot;
}
void modify(int cur)
{
    for (int x=cur,lst;x;x=lst)
        lst=mchx[pre[x]],mchx[pre[x]]=x,mchy[x]=pre[x];
}
void bfs(int cur)
{
    for (int i=1;i<=n;i++)
        slack[i]=INF,pre[i]=0;
    l=1,r=0;
    q[++r]=cur;
    ++tim;
    for (;;)
    {
        while (l<=r)
        {
            int u=q[l];
            l++;
            visx[u]=tim;
            for (int i=fr[u];i;i=nxt[i])
            {
                int v=d1[i],cost=d2[i];
                if (visy[v]==tim)
                    continue;
                ll del=ex[u]+ey[v]-cost;
                if (del<slack[v])
                {
                    slack[v]=del;
                    pre[v]=u;
                    if (!del)
                    {
                        visy[v]=tim;
                        if (!mchy[v])
                        {
                            modify(v);
                            return;
                        }
                        q[++r]=mchy[v];
                    }
                }
            }
        }
        ll del=INF;
        for (int i=1;i<=n;i++)
            if (visy[i]!=tim)
                del=min(del,slack[i]);
        l=1,r=0;
        for (int i=1;i<=n;i++)
        {
            if (visx[i]==tim)
                ex[i]-=del;
            if (visy[i]==tim)
                ey[i]+=del; else
                slack[i]-=del;
        }
        for (int i=1;i<=n;i++)
            if (visy[i]!=tim && !slack[i])
            {
                visy[i]=tim;
                if (!mchy[i])
                {
                    modify(i);
                    return;
                }
                q[++r]=mchy[i];
            }
    }
}
void KM()
{
    for (int i=1;i<=n;i++)
        bfs(i);
    ll ans=0;
    for (int i=1;i<=n;i++)
        ans+=ex[i]+ey[i];
    printf("%lld\n",ans);
    for (int i=1;i<=n;i++)
        printf("%d ",mchy[i]);
    putchar('\n');
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
        ex[i]=-INF,ey[i]=0;
    for (int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        ex[x]=max(ex[x],(ll)z);
    }
    KM();
    return 0;
}
posted @ 2020-09-15 18:17  GK0328  阅读(524)  评论(0编辑  收藏  举报