KM(Kuhn-Munkres Algorithm)

KM(Kuhn-Munkres Algorithm)

二分图最大权完美匹配。

问题引入

给定一张二分图,左右部均有 \(n\) 个点,共有 \(m\) 条带权边,且保证有完美匹配。

求一种完美匹配的方案,使得最终匹配边的边权之和最大。

定义

可行顶标

每个结点分配一个权值 \(l(i)\),对于所有边 \((u, v)\) 满足 \(w(u, v) \leq l(u) + l(v)\)

相等子图

包含所有点,但是只包含满足 \(w(u, v) = l(u) + l(v)\) 的边 \((u, v)\) 的子图称作相等子图

定理 1

对于某组可行顶标,如果其相等子图存在完美匹配,那么,该匹配就是原二分图的最大权完美匹配。

下文证明为了方便,默认左右部共有 \(2n\) 个点,编号为 \(1 \sim n\) 的认为是左部点,\(n + 1 \sim 2n\) 认为是右边部点。

证明如下:

考虑原二分图任意一组完美匹配 \(M\),其边权和为:

\[\operatorname {val}(M) = \sum_{(u, v) \in M} w(u, v) \leq \sum_{(u, v) \in M} (l(u) + l(v)) = \sum_{i = 1}^{2n} l(i) \]

然而,对于任意一组可行顶标相等子图的完美匹配 \(M'\),其边权和为:

\[\operatorname {val}(M') = \sum_{(u, v) \in M'} w(u, v) = \sum_{(u, v)\in M'}(l(u) + l(v)) = \sum_{i=1}^{2n}l(i) \]

故,完美匹配 \(M'\) 一定是该二分图的最大权完美匹配。

根据定理 1,我们有了一个简单的思路:

不断调整可行顶标,使得其相等子图存在完美匹配。又因为保证有完美匹配,所以,一定存在一组可行顶标使得其相等子图完美匹配

一些变量定义

\(lx(i)\) 表示左部点 \(i\) 的顶标,\(ly(i)\) 表示右部点 \(i\) 的顶标,\(w(u, v)\) 表示左部点 \(u\) 和右部点 \(v\) 之间的边权。

\(vx(i)\) 表示左部点 \(i\) 遍历标记,\(vy(i)\) 表示右部点 \(i\) 遍历标记。

\(px(i)\) 表示左部点 \(i\) 的匹配,\(py(i)\) 表示右部点 \(i\) 的匹配。

流程

假设此时左侧有点 \(x\) 在相等子图中的最大匹配中为非匹配点,从 \(x\) 开始尝试在相等子图中寻找增广路(由于是最大匹配了,所以肯定找不到),然后我们将访问到的左部点顶标减小 \(\Delta\),右部点顶标增大 \(\Delta\),考虑这样做的影响:

  1. 对于匹配边 \((u, v)\),显然,\(u, v\) 都会被 / 都不会被访问到,故 \(lx(u) + ly(v)\) 不变。
  2. 对于某个以访问到的左部点 \(u\) 为一端的非匹配边,由于 \(lx(u)\) 减小,故这条非匹配边可能加入到相等子图中。

故这么做既不会影响匹配边,还能将非匹配边加入相等子图,从而继续增广,而我们只要取 \(\Delta\) 为所有左部被访问到的点为一端的边 \((u, v)\) 中最小的 \(lx(u) + ly(v) - w(u, v)\) 即可,这样,至少加入了一条边(若取其他值取到 \(\min\) 的边会不满足顶标的性质)。

首先可以初始化左右部点的值:

\[lx(i) = \max_{1 \leq j \leq n} w(i, j) \\ ly(i) = 0 \]

每次加入一个左部点,按照上述在相等子图中进行增广,直到其加入最大匹配中。直接使用 DFS 实现的匈牙利算法,时间复杂度为 \(\mathcal O(n^4)\)

#include <bits/stdc++.h>
const long long inf = 1e18;
const int N = 500 + 10;
int n, m, w[N][N];
bool exist[N][N];
long long lx[N], ly[N];
int py[N], px[N];
long long d; bool vx[N], vy[N];

inline bool dfs(int u)
{
    vx[u] = 1;

    for (int i = 1; i <= n; ++ i)
        if (exist[u][i] && ! vy[i])
        {
            if (lx[u] + ly[i] == w[u][i])
            {
                vy[i] = 1;
                if (! py[i] || dfs(py[i]))
                {
                    px[u] = i; py[i] = u;
                    return 1;
                }
            }
            else d = std::min(d, 1ll * lx[u] + ly[i] - w[u][i]);
        }
    return 0;
}
int main()
{
    std::cin >> n >> m;

    for (int i = 1; i <= m; ++ i)
    {
        int u, v; std::cin >> u >> v;
        std::cin >> w[u][v];
        exist[u][v] = 1;
    }
    for (int i = 1; i <= n; ++ i)
    {
        lx[i] = - inf;
        for (int j = 1; j <= n; ++ j)
            if (exist[i][j])
                lx[i] = std::max(lx[i], 1ll * w[i][j]);
    }
    for (int i = 1; i <= n; ++ i)
    {
        while (1)
        {
            memset(vx, 0, sizeof(vx));
            memset(vy, 0, sizeof(vy));
            d = inf;
            if (dfs(i)) break;
            for (int j = 1; j <= n; ++ j)
            {
                if (vx[j]) lx[j] -= d;
                if (vy[j]) ly[j] += d;
            }
        }
    }
    long long ans = 0;
    for (int i = 1; i <= n; ++ i)
        ans += lx[i] + ly[i];
    std::cout << ans << std::endl;

    for (int i = 1; i <= n; ++ i)
        std::cout << py[i] << ' ';
    return 0;
}

考虑优化,在每一次加入一个左部点,尝试增广时,模拟匈牙利算法求增广路的过程,对于右部每个点 \(v\),记录 \(slack(v)\) 表示这一轮已经访问的左部点 \(u\)\(\min\{lx(u) + ly(v) - w(u, v)\}\) 的值。

当我们访问到一个左部点时,先用它更新所有右部点的 \(slack\) 值,接着取出右部 \(slack\) 最小的点,将其值设为 \(\Delta\),然后将当前已经访问到的左部点顶标减小 \(\Delta\),当前已经访问的右部点顶标增大 \(\Delta\),并更新 \(slack\) 数组。

下一个访问的左部点将是刚才取出的右部点的匹配点,这就是一个寻找增广路的过程,而当那个右部点是非匹配点时,我们已经找到了一条增广路。

重复上述过程,依次加入所有左部点,最后就求出了最小顶标和问题的解,也就是最大权完美匹配的方案。

#include <bits/stdc++.h>
const long long inf = 1e18;
const int N = 500 + 10;
int n, m, w[N][N];
bool exist[N][N];
long long lx[N], ly[N];
int px[N], py[N], pre[N];
long long slack[N];
bool vx[N], vy[N];
std::queue < int > Q;
inline void match(int v)
{
    int t = 0;
    while (v)
    {
        t = px[pre[v]];
        px[pre[v]] = v;
        py[v] = pre[v];
        v = t;
    }
}
inline void bfs(int s)
{
    memset(vx, 0, sizeof(vx));
    memset(vy, 0, sizeof(vy));
    for (int i = 1; i <= n; ++ i)
        slack[i] = inf;
    while (Q.size()) Q.pop();
    Q.push(s);

    while (1)
    {
        while (Q.size())
        {
            int u = Q.front(); Q.pop();
            vx[u] = 1;
            for (int i = 1; i <= n; ++ i)
                if (exist[u][i] && ! vy[i])
                {
                    if (lx[u] + ly[i] - w[u][i] < slack[i])
                    {
                        slack[i] = lx[u] + ly[i] - w[u][i];
                        pre[i] = u;
                        if (slack[i] == 0)
                        {
                            vy[i] = 1;
                            if (! py[i])
                            {
                                match(i);
                                return;
                            }
                            else Q.push(py[i]);
                        }
                    }
                }
        }
        long long d = inf;
        for (int i = 1; i <= n; ++ i)
            if (! vy[i]) d = std::min(d, slack[i]);
        for (int i = 1; i <= n; ++ i)
        {
            if (vx[i]) lx[i] -= d;
            if (vy[i]) ly[i] += d;
            else slack[i] -= d;
        }
        for (int i = 1; i <= n; ++ i)
            if (! vy[i] && slack[i] == 0)
            {
                vy[i] = 1;
                if (! py[i])
                {
                    match(i);
                    return;
                }
                else Q.push(py[i]);
            }
    }
}
int main()
{
    std::cin >> n >> m;

    for (int i = 1; i <= m; ++ i)
    {
        int u, v; std::cin >> u >> v;
        std::cin >> w[u][v];
        exist[u][v] = 1;
    }
    for (int i = 1; i <= n; ++ i)
    {
        lx[i] = - inf;
        for (int j = 1; j <= n; ++ j)
            if (exist[i][j])
                lx[i] = std::max(lx[i], 1ll * w[i][j]);
    }
    for (int i = 1; i <= n; ++ i)
        bfs(i);

    long long ans = 0;
    for (int i = 1; i <= n; ++ i)
        ans += lx[i] + ly[i];
    std::cout << ans << std::endl;

    for (int i = 1; i <= n; ++ i)
        std::cout << py[i] << ' ';
    return 0;
}

Reference

  1. OI-wiki
  2. ix35
posted @ 2022-03-14 00:07  chzhc  阅读(277)  评论(0编辑  收藏  举报
levels of contents