KM算法及其优化的学习笔记&&bzoj2539: [Ctsc2000]丘比特的烦恼
感谢 http://www.cnblogs.com/vongang/archive/2012/04/28/2475731.html
这篇blog里提供了3个链接……基本上很明白地把KM算法是啥讲清楚了
然而n^4的KM好像并没有什么卵用啊……所以不得不学n^3的
我看了一下各种,大部分blog里写的声称是n^3的KM,其实貌似都是n^4的(包括上面的链接以及上面链接里提供的链接)
这是因为他们有个共同点
他们虽然用slack数的优化组避免了暴力枚举d所消耗的时间,但由于一次增广是n^2的,所以拖慢了复杂度
那么怎么解决这个问题呢?
尛焱轟告诉我们,用bfs增广的KM是n^3的,用dfs增广的KM是n^4的
尛焱轟还告诉我们,可以去UOJ上拉个板子,都是n^3的
于是窝就拉了个策爷的板子来看(然后改了几个变量名,背下来,就学完了……)
为什么dfs会成为算法时间复杂度减小的瓶颈呢?
我们发现,每更新顶标,就要重新从当前点开始dfs找一遍增广路,有很多冗余的操作
实际上,更新完顶标之后,交错树只会增加新的点
那么窝萌不妨用bfs来增广,每次修改完顶标,把没访问到的右侧点的slack值也相应地减去d,那么slack值为0就相当于多了一条可行边,就相当于能够访问到新的节点,也就可以继续找增广路了
这样再把新的点加进队列,就避免了dfs增广的版本中的冗余操作
这样就发挥了slack这一优化的优势,复杂度自然降到O(n^3)
然后窝来帖一下窝的代码(uoj#80 二分图最大权匹配)
#include <iostream> #include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <algorithm> #define ll long long #define N 405 #define INF (1LL<<60) using namespace std; inline int read(){ int ret=0;char ch=getchar(); while (ch<'0'||ch>'9') ch=getchar(); while ('0'<=ch&&ch<='9'){ ret=ret*10-48+ch; ch=getchar(); } return ret; } int n,fx[N],fy[N],prev[N]; ll g[N][N],A[N],B[N],slk[N]; bool visx[N],visy[N]; int q[N],qh,qt; void aug(int v){ if (!v) return; fy[v]=prev[v]; aug(fx[prev[v]]); fx[fy[v]]=v; } void bfs_KM(int _s){ memset(visx,0,sizeof(visx)); memset(visy,0,sizeof(visy)); memset(slk,127,sizeof(slk)); qh=qt=0; q[++qt]=_s; for (;;){ while (qh<qt){ int u=q[++qh]; visx[u]=1; for (int v=1;v<=n;++v)if (!visy[v]){ if (A[u]+B[v]==g[u][v]){ visy[v]=1; prev[v]=u; if (!fy[v]){aug(v);return;} q[++qt]=fy[v]; continue; } if (slk[v]>A[u]+B[v]-g[u][v]){ slk[v]=A[u]+B[v]-g[u][v]; prev[v]=u; } } } ll d=INF; for (int i=1;i<=n;++i) if (!visy[i]) d=min(d,slk[i]); for (int i=1;i<=n;++i){ if (visx[i]) A[i]-=d; if (visy[i]) B[i]+=d; else slk[i]-=d; } for (int v=1;v<=n;++v)if (!visy[v]&&!slk[v]){ visy[v]=1; if (!fy[v]){aug(v);return;} q[++qt]=fy[v]; } } } ll KM(){ for (int i=1;i<=n;++i){ A[i]=-INF;B[i]=0; for (int j=1;j<=n;++j) A[i]=max(A[i],g[i][j]); } memset(fx,0,sizeof(fx)); memset(fy,0,sizeof(fy)); for (int i=1;i<=n;++i) bfs_KM(i); ll ret=0; for (int i=1;i<=n;++i) ret+=A[i]+B[i]; return ret; } bool e0[N][N]; int main(){ int nl=read(),nr=read(); memset(g,0,sizeof(g)); memset(e0,0,sizeof(e0)); for (int m0=read();m0;--m0){ int u=read(),v=read(); g[u][v]=read(); e0[u][v]=1; } n=max(nl,nr); ll ans=KM(); printf("%lld\n",ans); for (int i=1;i<=nl;++i) printf("%d ",e0[i][fx[i]]*fx[i]); puts(""); return 0; }
感谢尛焱轟神犇的指点
感谢jcvb神犇的代码
感谢上面的那篇blog以及那篇blog里的链接
更新一下,窝在丘比特的烦恼(KM模板题)里把KM封装了一下,方便大(我)家(拖)学(板)习(子)
顺便提一下此题的几个坑爹的地方:1坐标可能为负,2姓名无视大小写,3必须连n对情侣(也就是说不能连的边权必须赋为-INF)
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <cstdlib> #include <algorithm> #include <map> #include <string> #define N 32 #define INF (1e9) using namespace std; inline int read(){ int ret=0;char ch=getchar(); bool flag=0; while (ch<'0'||ch>'9'){ flag=ch=='-'; ch=getchar(); } while ('0'<=ch&&ch<='9'){ ret=ret*10-48+ch; ch=getchar(); } return flag?-ret:ret; } struct KM{ int n; int g[N][N],slk[N],A[N],B[N]; int prev[N],fx[N],fy[N]; bool visx[N],visy[N]; int q[N],qh,qt; void clear(int _n){ n=_n;memset(g,0,sizeof(g)); } void AddEdge(int u,int v,int w){ g[u][v]=w; } void aug(int v){ if (!v) return; fy[v]=prev[v]; aug(fx[fy[v]]); fx[fy[v]]=v; } void bfs(int _s){ memset(visx,0,sizeof(visx)); memset(visy,0,sizeof(visy)); memset(slk,127,sizeof(slk)); qh=qt=0;q[++qt]=_s; for (;;){ while (qh<qt){ int u=q[++qh]; visx[u]=1; for (int v=1;v<=n;++v)if (!visy[v]){ if (A[u]+B[v]==g[u][v]){ visy[v]=1; prev[v]=u; if (!fy[v]){aug(v);return;} q[++qt]=fy[v]; } else if (slk[v]>A[u]+B[v]-g[u][v]){ slk[v]=A[u]+B[v]-g[u][v]; prev[v]=u; } } } int d=INF; for (int i=1;i<=n;++i)if (!visy[i])d=min(d,slk[i]); for (int i=1;i<=n;++i){ if (visx[i]) A[i]-=d; if (visy[i]) B[i]+=d; else slk[i]-=d; } for (int v=1;v<=n;++v) if (!visy[v]&&!slk[v]){ visy[v]=1; if (!fy[v]){aug(v);return;} q[++qt]=fy[v]; } } } int solve(){ memset(A,128,sizeof(A)); memset(B,0,sizeof(B)); memset(fx,0,sizeof(fx)); memset(fy,0,sizeof(fy)); for (int i=1;i<=n;++i) for (int j=1;j<=n;++j) A[i]=max(A[i],g[i][j]); for (int i=1;i<=n;++i) bfs(i); int ret=0; for (int i=1;i<=n;++i) ret+=A[i]+B[i]; return ret; } } km; int n; map<string,int> id; int x[N*2],y[N*2],lmt; void Upper(string &s){ int l=s.length(); for (int i=0;i<l;++i)if (s[i]>'Z') s[i]-=32; } int main(){ string tmp; lmt=read();n=read(); for (int i=1;i<=2*n;++i){ x[i]=read();y[i]=read(); cin>>tmp;Upper(tmp);id[tmp]=i; } km.clear(n); for (int i=1;i<=n;++i) for (int j=1;j<=n;++j) km.AddEdge(i,j,1); for (cin>>tmp;tmp!="End";cin>>tmp){ Upper(tmp); int u=id[tmp],v; cin>>tmp;Upper(tmp);v=id[tmp]; if (u>v) swap(u,v); km.AddEdge(u,v-n,read()); } for (int i=1;i<=n;++i) for (int j=n+1;j<=2*n;++j){ bool found=(x[i]-x[j])*(x[i]-x[j])+(y[i]-y[j])*(y[i]-y[j])>lmt*lmt; for (int k=1;k<=2*n&&!found;++k){ int A=y[k]-y[j],B=x[k]-x[j],C=y[k]-y[i],D=x[k]-x[i]; found=A*D==B*C&&(A*C<0||B*D<0); } if (found) km.AddEdge(i,j-n,-1e7); } printf("%d\n",km.solve()); return 0; }