poj 3686 KM算法
有N个工件要在M个机器上加工,有一个N*M的矩阵描述其加工时间。
同一时间内每个机器只能加工一个工件,问加工完所有工件后,使得平均加工时间最小(等待的时间+加工的时间)。
假设某个机器处理了k个玩具,时间分别为a1,a2…..,ak
那么该机器耗费的时间为a1+(a1+a2)+(a1+a2+a3).......(a1+a2+...ak)
即a1*k + a2 * (k - 1) + a3 * (k - 2).... + ak
ai玩具在某个机器上倒数第k个处理,所耗费全局的时间为ai*k
对每个机器,最多可以处理n个玩具,拆成n个点,1~n分别代表某个玩具在这个机器上倒数第几个被加工的,对于每个玩具i,机器j中拆的每个点k,连接一条w[i][j]*k权值的边
1 #include <iostream> 2 #include <cstring> 3 #include <cstdio> 4 #include <algorithm> 5 #include <cmath> 6 7 using namespace std; 8 9 #define MAXN 55 10 #define inf 0x7ffffff 11 12 int w[MAXN][2555]; 13 int lx[MAXN],ly[2555]; 14 int linky[2555]; 15 int visx[MAXN],visy[2555]; 16 int slack[2555]; 17 int nx,ny; 18 int n,m; 19 20 bool find(int x) 21 { 22 visx[x]=1; 23 for(int y=1;y<=ny;y++) 24 { 25 if(visy[y]) continue; 26 int t=lx[x]+ly[y]-w[x][y]; 27 if(t==0) 28 { 29 visy[y]=1; 30 if(linky[y] ==-1 || find(linky[y])) 31 { 32 linky[y]=x; 33 return true;//找到增广路 34 } 35 } 36 else if(slack[y] > t) 37 slack[y]=t; 38 } 39 return false; 40 } 41 42 int KM() 43 { 44 memset(linky,-1,sizeof(linky)); 45 memset(ly,0,sizeof(ly)); 46 for(int i=1;i<=nx;i++) 47 { 48 lx[i]=-inf; 49 for(int j=1;j<=ny;j++) 50 if(w[i][j] >lx[i]) 51 lx[i]=w[i][j]; 52 } 53 for(int x=1;x<=nx;x++) 54 { 55 for(int i=1;i<=ny;i++) 56 slack[i]=inf; 57 while(1) 58 { 59 memset(visx,0,sizeof(visx)); 60 memset(visy,0,sizeof(visy)); 61 if(find(x)) break; 62 int d=inf; 63 for(int i=1;i<=ny;i++) 64 if(!visy[i] && d>slack[i]) 65 d=slack[i]; 66 for(int i=1;i<=nx;i++) 67 if(visx[i]) 68 lx[i]-=d; 69 for(int i=1;i<=ny;i++) 70 if(visy[i]) 71 ly[i]+=d; 72 else 73 slack[i]-=d; 74 } 75 } 76 int ans=0; 77 for(int i=1;i<=ny;i++) 78 if(linky[i] >-1) 79 ans+=w[linky[i]][i]; 80 return -ans; 81 } 82 83 void init() 84 { 85 scanf("%d%d",&n,&m); 86 nx=n; 87 ny=n*m; 88 int cost; 89 for(int i=1;i<=n;i++) 90 { 91 int cnt=1; 92 for(int j=1;j<=m;j++) 93 { 94 scanf("%d",&cost); 95 for(int k=1;k<=n;k++) 96 { 97 w[i][cnt++]=-cost*k; 98 } 99 } 100 } 101 } 102 103 int main() 104 { 105 int t; 106 scanf("%d",&t); 107 while(t--) 108 { 109 init(); 110 double ans=1.0*KM()/n; 111 printf("%.6f\n",ans); 112 } 113 return 0; 114 }