二分图最大权完美匹配

洛谷P6577二分图最大权完美匹配(50分)后需bfs优化

因为边权取值可以为负,所以开始初始化为-INF;

ans为ll,km函数返回值为ll;

板子1:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 555;
const int INF = 0x3f3f3f3f;
typedef long long ll;
int val[maxn][maxn], vis_A[maxn], vis_B[maxn], match[maxn];     //vis来记录已被匹配的,match记录B方匹配到了A方的哪个
int ex_A[maxn], ex_B[maxn], slack[maxn], m, n, e;                        //记录A方和B方的期望, A方有m个,B方有n个
//slack 任意一个参与匹配A方能换到任意一个这轮没有被选择过的B方所需要降低的最小值
//这是二分图的最优匹配(首先是A集合的完备匹配,然后保证权值最大)
//所以一定保证 m <= n, 否则会陷入死循环,若是A集合点多的话可以把B集合补充到和A一样多,设置-INF的边
bool dfs(int x)
{
    vis_A[x] = 1;
    for(int i = 1; i <= n; i++)
    {
        if(!vis_B[i])                             //每一轮匹配,B方每一个点只匹配一次
        {
            int gap = ex_A[x] + ex_B[i] - val[x][i];
            if(gap == 0)                         //如果符合要求
            {
                vis_B[i] = 1;
                if(!match[i] || dfs(match[i]))   //如果v尚未匹配或者匹配了可以被挪走
                {
                    match[i] = x;
                    return true;
                }
            }
            else slack[i] = min(slack[i], gap);
        }
    }
    return false;
}

ll km()
{
    memset(match, 0, sizeof(match));            //match为0表示还没有匹配
    fill(ex_B + 1, ex_B + 1 + n, 0);             //B方一开始期望初始化为0
    for(int i = 1; i <= m; i++)                 //A方期望取最大值
    {
        ex_A[i] = val[i][1];
        for(int j = 2; j <= n; j++)
            ex_A[i] = max(ex_A[i], val[i][j]);
    }
    for(int i = 1; i <= m; i++)               //尝试解决A方的每一个节点
    {
        memset(slack + 1, INF, sizeof(slack[0]) * n);
        for(;;)
        {
            memset(vis_A + 1, 0, sizeof(vis_A[0]) * m);     //记录AB双方有无被匹配过
            memset(vis_B + 1, 0, sizeof(vis_B[0]) * n);
            if(dfs(i))  break;
            int d = INF;
            for(int j = 1; j <= n; j++)    if(!vis_B[j]) d = min(d, slack[j]);
            //if(d == INF)  break;                        //找不到完全匹配
            for(int j = 1; j <= m; j++) if(vis_A[j]) ex_A[j] -= d;
            for(int j = 1; j <= n; j++)
            {
                if(vis_B[j]) ex_B[j] += d;
                else slack[j] -= d;
            }
        }
    }
    ll ans = 0;
    for(int i = 1; i <= n; i++)
    {
        if(match[i])              // 可以加 && val[match[i]][i] > -INF 去除一些匹配
            ans += val[match[i]][i];
    }
    return ans;
}

int main()
{
    scanf("%d %d", &n, &e);
    int x, y, w;
    m = n;
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= n; ++j)
            val[i][j] = -INF;
    for(int i = 1; i <= e; ++i)
    {
        scanf("%d %d %d", &x, &y, &w);
        val[x][y] = w;
    }
    cout << km() << endl;
    for(int i = 1; i <= n; ++i)
        if(match[i]) printf("%d ", match[i]);
}

 

 

板子2:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 555;
const int INF = 0x3f3f3f3f;
typedef long long ll;
int val[maxn][maxn], vis_a[maxn], vis_b[maxn], match[maxn];
int ex_a[maxn], ex_b[maxn], n, m, e;

bool dfs(int x)
{
    vis_a[x] = 1;
    for(int i = 1; i <= m; ++i)
    {
        if(!vis_b[i] && ex_a[x] + ex_b[i] == val[x][i])
        {
            vis_b[i] = 1;
            if(!match[i] || dfs(match[i]))
            {
                match[i] = x;
                return 1;
            }
        }
    }
    return 0;
}

ll km()
{
    memset(match, 0, sizeof(match));
    fill(ex_b + 1, ex_b + 1 + n, 0);
    for(int i = 1; i <= n; ++i)
    {
        ex_a[i] = val[i][1];
        for(int j = 2; j <= m; ++j)
            ex_a[i] = max(ex_a[i], val[i][j]);
    }
    for(int i = 1; i <= m; ++i)
    {
        while(1)
        {
            memset(vis_a + 1, 0, sizeof(vis_a[0]) * n);
            memset(vis_b + 1, 0, sizeof(vis_b[0]) * m);
            if(dfs(i)) break;
            int d = INF;
            for(int j = 1; j <= n; ++j)
                if(vis_a[j])
                    for(int k = 1; k <= m; ++k)
                        if(!vis_b[k]) d = min(d, ex_a[j] + ex_b[k] - val[j][k]);
            for(int j = 1; j <= n; ++j) if(vis_a[j]) ex_a[j] -= d;
            for(int j = 1; j <= m; ++j) if(vis_b[j]) ex_b[j] += d;
        }
    }
    ll ans = 0;
    for(int i = 1; i <= n; ++i)
    {
        if(match[i]) ans += val[match[i]][i];
    }
    return ans;
}

int main()
{
    scanf("%d %d", &n, &e);
    int x, y, w;
    m = n;
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= n; ++j)
            val[i][j] = -INF;
    for(int i = 1; i <= e; ++i)
    {
        scanf("%d %d %d", &x, &y, &w);
        val[x][y] = w;
    }
    cout << km() << endl;
    for(int i = 1; i <= n; ++i)
        if(match[i]) printf("%d ", match[i]);
}

 

posted @ 2020-12-07 00:03  .Ivorelectra  阅读(180)  评论(0编辑  收藏  举报