KM bfs写法
KM bfs写法
2018astar资格赛的第三题整数规划。
把\(x, y\)看成二分图两边的顶标,\(a_{ij}\)就是二分图的边权,整道题其实就是求二分图的最大权匹配。
然后打了个\(dfs\)的\(KM\),\(TLE\)了,后来听别人说要用\(bfs\)的写法,因为那个才是真正的\(O(n^3)\),\(dfs\)的写法最坏情况还是\(O(n^4)\)。
原理是一样的,只不过\(bfs\)有一点点像迭代,每一次也只是搜\(diff=0\)的情况,而且右边的点只会搜索一次(或者说是左边的点只会搜索一次,即左边的每个点只会进队一次),用\(pre\)记住当前的交错路径,找到未匹配的就可以沿交错路径进行修改。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=210;
const LL inf=1LL<<60;
int n;
namespace KM
{
int n;
LL mat[maxn][maxn]; //边权
int matcha[maxn], matchb[maxn]; //左边的点匹配的右边点;右边的点匹配的左边点
LL marka[maxn], markb[maxn]; //左顶标;右顶标
LL slack[maxn]; //松弛数组
bool visa[maxn], visb[maxn]; //访问标记
int head, tail;
int q[maxn], pre[maxn]; //队列;交错路径
bool check(int cur)
{
visb[cur]=true; //标记cur已搜索
if (matchb[cur]) //已匹配,即当前匹配失败
{
if (!visa[matchb[cur]]) //匹配的点是否已进队
{
q[++tail]=matchb[cur];
visa[matchb[cur]]=true;
}
return false;
}
//未匹配,即当前匹配成功,沿交错路径进行匹配
while (cur)
swap(cur, matcha[matchb[cur]=pre[cur]]);
return true;
}
void bfs(int start)
{
fill(visa, visa+1+n, false);
fill(visb, visb+1+n, false);
fill(slack, slack+1+n, inf);
head=tail=1;
q[1]=start;
visa[start]=true;
while (1)
{
while (head<=tail)
{
int cur=q[head++];
for (int i=1; i<=n; ++i)
{
LL diff=marka[cur]+markb[i]-mat[cur][i];
if (!visb[i] && diff<=slack[i]) //visb=true说明已搜索,无需更新slack和pre,也是保证pre的正确性
{
slack[i]=diff;
pre[i]=cur;
if (diff==0) //diff=0,可以尝试匹配
if (check(i)) return; //匹配成功可直接返回
}
}
}
LL delta=inf;
for (int i=1; i<=n; ++i)
if (!visb[i] && slack[i]) delta=min(slack[i], delta);
for (int i=1; i<=n; ++i) //松弛
{
if (visa[i]) marka[i]-=delta;
if (visb[i]) markb[i]+=delta;
else slack[i]-=delta; //维护slack的正确性(参考diff的计算及marka,markb的变化)
}
head=1, tail=0;
for (int i=1; i<=n; ++i)
if (!visb[i] && !slack[i] && check(i)) return;
//松弛后尝试匹配diff=0的点。
}
}
void solve()
{
fill(matcha, matcha+1+n, 0);
fill(matchb, matchb+1+n, 0);
fill(markb, markb+1+n, 0);
for (int i=1; i<=n; ++i)
{
marka[i]=0;
for (int j=1; j<=n; ++j)
marka[i]=max(marka[i], mat[i][j]);
}
for (int i=1; i<=n; ++i) bfs(i);
}
}
void read()
{
scanf("%d", &n);
KM::n=n;
for (int i=1; i<=n; ++i)
for (int j=1; j<=n; ++j)
{
int x;
scanf("%d", &x);
KM::mat[i][j]=-x;
}
}
void solve()
{
KM::solve();
LL ans=0;
for (int i=1; i<=n; ++i)
ans+=KM::marka[i]+KM::markb[i];
printf("%lld\n", -ans);
}
int main()
{
int casesum;
scanf("%d", &casesum);
for (int i=1; i<=casesum; ++i)
{
printf("Case #%d: ", i);
read();
solve();
}
return 0;
}