KM算法
二分图带权匹配
基本概念
交错树:DFS过程中, 所有访问过的节点以及这些节点形成的边构成的树。(即寻找增广路过程中形成的树。)
顶标:每个顶点赋予的顶点标记值。满足:\(a_i + b_j >= w(i, j)\);
相等子图:满足 \(a_i + b_j == w(i, j)\) 的边构成的子图。
定理:当相等子图中存在完美匹配时, 这个完美匹配就是二分图的带权最大匹配。
证明:显而易见。
KM算法
算法流程:
- 初始化可行顶标的值 (设定la,lb的初始值)。
- 用匈牙利算法寻找相等子图的完备匹配。
- 若未找到增广路则修改可行顶标的值。
- 重复(2)(3)直到找到相等子图的完备匹配为止。
修改方法:对于交错树中所有左边顶点,减去一个常数delta, 对于交错树中所有的右顶点,加上这个常数delta。其中delta为所有不在交错树中 \(a_i + b_j - w(i, j)\) 的最小值。
模板题
【题意】:就是模板题。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 310;
const int inf = 0x3f3f3f3f;
const double eps = 1e-6, pi = acos(-1);
int gcd(int a,int b) { return b == 0 ? a : gcd(b, a % b); }
int n, match[maxn];
bool va[maxn], vb[maxn];
LL upd[maxn], w[maxn][maxn], delta;
LL la[maxn], lb[maxn];
bool dfs(int u)
{
va[u] = 1;
for(int v = 1; v <= n; ++v){
if(vb[v]) continue;
if(la[u] + lb[v] - w[u][v] == 0){
vb[v] = 1;
if(!match[v] || dfs(match[v])){
match[v] = u;
return 1;
}
}
else{
upd[v] = min(upd[v], la[u] + lb[v] - w[u][v]);
}
}
return 0;
}
LL km()
{
for(int i = 1; i <= n; ++i){
la[i] = -inf;
lb[i] = 0;
for(int j = 1; j <= n; ++j){
la[i] = max(la[i], w[i][j]);
match[j] = 0;
}
}
for(int i = 1; i <= n; ++i){
while(true){
memset(va, 0, sizeof(va));
memset(vb, 0, sizeof(vb));
memset(upd, inf, sizeof(upd));
if(dfs(i)) break;
delta = inf;
for(int j = 1; j <= n; ++j){
if(!vb[j]) delta = min(delta, upd[j]);
}
for(int j = 1; j <= n; ++j){
if(va[j]) la[j] -= delta;
if(vb[j]) lb[j] += delta;
}
}
}
LL ans = 0;
for(int i = 1; i <= n; ++i) ans += w[match[i]][i];
return ans;
}
int main()
{
while(~scanf("%d", &n)){
for(int i = 1; i <= n; ++i){
for(int j = 1; j <= n; ++j){
scanf("%lld", &w[i][j]);
}
}
printf("%lld\n", km());
}
system("pause");
}
蚂蚁
【题意】:2 * n 个点, 黑白点各n个,怎么匹配黑白点使得每一对匹配之间的连线没有交点。
【思路】:随便举一个例子画个图形会发现:如果两个匹配之间没有交点,那么他们之间的距离之和一定最小。可以用三角形两边之和大于第三边来反证它。
注意:
- 判断实数之间的大小和是否相等一定要加 \(eps!!\)
- 用km算法求最小匹配时可以将所有边的权值取相反数,这样变成了最大匹配,输出 \(-ans\) 即可。
- \(dfs\) 里面求 \(delta\) 常数较大,可以放在 \(km\)函数中求并用 \(va[i]\ \ \&\& \ \ !vb[j]\) 来剪枝。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 110;
const int inf = 0x3f3f3f3f;
const double eps = 1e-6, pi = acos(-1);
int gcd(int a,int b) { return b == 0 ? a : gcd(b, a % b); }
struct node{
int x, y;
}a[maxn], b[maxn];
int va[maxn], vb[maxn], n, match[maxn];
double w[maxn][maxn];
double la[maxn], lb[maxn];
int ans[maxn];
inline double dis(node a, node b) { return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y)); }
bool dfs(int u)
{
va[u] = 1;
for(int v = 1; v <= n; ++v){
if(vb[v]) continue;
if(fabs(la[u] + lb[v] - w[u][v]) < eps){ //实数比大小一定要加eps!!!
vb[v] = 1;
if(!match[v] || dfs(match[v])){
match[v] = u;
return 1;
}
}
}
return 0;
}
void km()
{
for(int i = 1; i <= n; ++i){
la[i] = -inf;
lb[i] = 0;
for(int j = 1; j <= n; ++j){
la[i] = max(la[i], w[i][j]);
}
}
for(int i = 1; i <= n; ++i){
while(true){
memset(va, 0, sizeof(va));
memset(vb, 0, sizeof(vb));
if(dfs(i)) break;
double delta = inf;
for(int j = 1; j <= n; ++j){
if(!va[j]) continue;
for(int k = 1; k <= n; ++k){
if(vb[k]) continue;
delta = min(delta, la[j] + lb[k] - w[j][k]);
}
}
for(int j = 1; j <= n; ++j){
if(va[j]) la[j] -= delta;
if(vb[j]) lb[j] += delta;
}
}
}
for(int i = 1; i <= n; ++i){
ans[match[i]] = i;
}
for(int i = 1; i <= n; ++i){
printf("%d\n", ans[i]);
}
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; ++i){ //黑点
scanf("%d %d", &b[i].x, &b[i].y);
}
for(int i = 1; i <= n; ++i){
scanf("%d %d", &a[i].x, &a[i].y);
}
for(int i = 1; i <= n; ++i){
for(int j = 1; j <= n; ++j){
w[i][j] = -dis(b[i], a[j]);
}
}
km();
system("pause");
}