描述
给你一个m x n (1 <= m, n <= 100)的矩阵A (0<=aij<=10000),要求在矩阵中选择一些数,要求每一行,每一列都至少选到了一个数,使得选出的数的和尽量的小。
输入
多组测试数据。首先是数据组数T
对于每组测试数据,第1行是两个正整数m, n,分别表示矩阵的行数和列数。
接下来的m行,每行n个整数,之间用一个空格分隔,表示矩阵A的元素。
输出
每组数据输出一行,表示选出的数的和的最小值。
数据范围
小数据:1 <= m, n <= 5
大数据:1 <= m, n <= 100
- 样例输入
-
2 3 3 1 2 3 3 1 2 2 3 1 5 5 1 2 3 4 5 5 1 2 3 4 4 5 1 2 3 3 4 5 1 2 2 3 4 5 1
- 样例输出
-
Case 1: 3 Case 2: 5
借鉴 : Taptree
FIRST:
这里补充几点:
1:最优方案至多包含n+m-1个点, 因为我们按顺序扫描这些点,每个点必然覆盖一个新的行或列,否则没有意义,这样第一个点会覆盖一个新的行+一个新的列,因此最优方案中的点数<=n+m-1。
2:此处是带上下界的最小费用可行流(只要可行流,不需要最大流)。上下界费用流不能直接用最短增广路算法进行。需要改构图,方法略烦,此处不表。
但是这题具有特殊性,
在构图的时候,对于
(2)从源S向所有Ri连边,流量限制为1 <= f <= n,费用设为0;
可以改成
S->Ri 流量0<=f<=1 费用-100000(本题中A[i][j]<=10000,此处取个较大的负数即可)
S->Ri 流量0<=f<=m 费用0
(4)从所有Cj向漏T连边,流量限制为 1 <= f <=m,费用设为0。
类似的
Cj->T 流量0<=f<=1 费用-100000(本题中A[i][j]<=10000,此处取个较大的负数即可)
Cj->T 流量0<=f<=n 费用0
最后求得的答案加上100000*(n+m)即可,这样只需要能处理负边的最小费用流算法即可。
SECOND:
大体思路是带上下界的费用流。
【构图】
(1)建立源S和漏T,对第i行设一个节点Ri,对第j列设一个节点Cj。
(2)从源S向所有Ri连边,流量限制为1 <= f <= n,费用设为0;
(3)从所有Ri向Cj连边,流量限制为 0 <=f <=1,费用设为A[i][j];
(4)从所有Cj向漏T连边,流量限制为 1 <= f <=m,费用设为0。
此图共V = m+n+2 = O(m+n)个顶点,E = O(m*n)条边。
___________________________________________________________________________
【求解步骤】
采用连续最短路算法进行增广。
(1)此网络中的初始可行流的费用就对应于只选 k0 = max{m, n}个数时的最小和;
(2)然后采用连续最短路算法进行增广,由于网络的特殊性,每次增广的流量是1。增广一次之后,新的费用就是只选k0+1个数时的最小和;
(3)不断重复步骤(2),直到流值等于m+n-1,就可以求得选取k0+2, k0+3, ... , m+n-1个元素时的最小和;
(4)在上面求得的所有最小和中,最小的一个就是题目要的答案。
___________________________________________________________________________
【复杂度分析】
空间复杂度:O(m*n);
时间复杂度:O(C * k * E),其中C是增广的次数,也即最大流的流量m+n-1。假定采用SPFA找最短路,每次找增广路的时间就是O(k*E),k可以看成常数2,E = O(m*n)为边数。
所以O(C * k * E) = O(k * (m+n-1) * m * n),当问题接近极限规模m=n=100时,k * C * m * n = 400w,再考虑到本图实际上是二分图,复杂度可能虚高,因此运行时间可接受。
(复杂度虚高的意思是,虽然复杂度看起来比较可怕,但算法实际运行起来很快。比如Hopcroft Karp算法复杂度是O(sqrt(V)*E),但是实际上甚至足够在1~2s内求解多达3000个点的二分图。)
___________________________________________________________________________
【附记】
1、增广到流量为m+n-1时即停止的原因是,最优解所选取的元素个数必然不超过m+n-1。详见 FIRST 。
2、复杂度确实有些虚高,代码实际只运行了160ms。
___________________________________________________________________________
【优化】
通过引入绝对值很大的负权边引导网络流优先走成可行流,可以避免求解上下界网络流重新构图。详见 FIRST 。
其实最后实现我也用了这个方法,因为懒得写重构图的代码了= =
___________________________________________________________________________
【对网络流的理解】
这里讲一下自己对网络流问题的理解吧:
重点是学会抽象的思考方式。
不管什么类型的网络流,求解的基本方法都是Ford–Fulkerson method(也有其他类别的算法,暂且按下不表),也即以下三大步:
1、找一个初始流;
2、在网络中找增广路,沿增广路对已经得到的网络流进行增广;
3、反复重复步骤2,直到网络中不存在增广路为止。
对于普通的网络流,步骤1被省掉了,因为零流就是一个可行流。
对于带上下界的网络流,步骤1通过 ”对重构图求最大流“ 实现,后面的步骤2,3都是在原图中做的。这里就不要去想第一步里 “对重构图求最大流” 的过程了,这只是一种数学手段,假如你能够用别的方法找到原图中的一个可行流,那么也是可以的。只要知道,并且相信步骤1中的操作能够完成 “找到初始可行流” 这个寒碜的目的就够了。
对于最小费用最大流,若采用连续最短路算法(还有其他最小费用流算法,此处按下不表),只要把步骤2的 “找增广路” 改成 “找最小费用的增广路” 即可。
对于带上下界的费用流,那就等于 带上下界的网络流+费用流,也即:步骤1按 “带上下界的网络流” 那样操作, 步骤2按 “最小费用最大流” 操作即可。
最后说这道题:其实完整的过程应该是一直寻找增广路,直到找到原图中的最大流(也就是选取全部m*n个元素)。不过,这道题又没问最大流,所以没有必要求出最大流,于是我让算法在半路上(也就是流值为m+n-1的时候)就停了下来。
这道题构图的核心思想是,图中(未加指明就是原图,而非重构图。再次声明:重构图只是为了找到原图中的一个可行流而做的一种数学处理,不是本质上必须的,不要去想有两张图,而是要把“在重构图中求最大流” 看成 “找到原图的一个可行流”,这才是本质)的一个流量为F的最小费用流,对应了题目中选取F个元素的最优解,全局的最优解必然满足选取的元素个数介于max{m,n}和m+n-1之间,因此图中的流量增加到m+n-1的时候停下来即可。
____________________________________________________________________________
【代码】
C和C++混编的,为了用STL里的queue(我真是太懒)。
没有注释,不过代码结构还算清晰吧,这也是很多人用的模板。
哦对,这份代码其实不是我回答里提到的算法,而是用了 FIRST 提到的那个优化:通过设置绝对值很大的负权边引导网络流,从而自然形成了初始的最小可行流,避免了复杂的重构图。(这也说明了抽象思考的重要性:重构图本身不是必须的,只是为了找到一个初始可行流而已;假如能用别的方法找到初始可行流,就不用重构图了。)
1 #include <stdio.h> 2 #include <iostream> 3 #include <string.h> 4 #include <algorithm> 5 #include <queue> 6 using namespace std; 7 #define SIZE 210 8 #define BOUND 100000 9 const int inf = 1<<29; 10 struct node{ 11 int s,t,f,w; 12 int next; 13 }edge[SIZE*SIZE+50]; 14 int head[SIZE]; 15 int mincost,tot; 16 int s,t,pre[SIZE]; 17 void add(int s,int t,int w,int f) 18 { 19 edge[tot].f=f; 20 edge[tot].w=w; 21 edge[tot].t=t; 22 edge[tot].s=s; 23 edge[tot].next=head[s]; 24 head[s]=tot++; 25 } 26 void addedge(int s,int t,int w,int f) 27 { 28 add(s,t,w,f); 29 add(t,s,-w,0); 30 } 31 int n; 32 bool spfa() 33 { 34 bool vis[SIZE]; 35 memset(vis,false,sizeof(vis)); 36 int d[SIZE]; 37 int i=n+2; 38 while(i--)d[i]= inf; 39 d[s]=0; 40 vis[s]=true; 41 queue<int> Q; 42 Q.push(s); 43 pre[s]=-1; 44 while(!Q.empty()) 45 { 46 int u=Q.front(); 47 Q.pop(); 48 vis[u]=false; 49 50 for(int i=head[u];i!=-1;i=edge[i].next) 51 { 52 if(edge[i].f>0&&d[u]+edge[i].w<d[edge[i].t]) 53 { 54 d[edge[i].t]=d[u]+edge[i].w; 55 pre[edge[i].t]=i; 56 if(!vis[edge[i].t]) 57 { 58 Q.push(edge[i].t); 59 vis[edge[i].t]=true; 60 } 61 } 62 } 63 } 64 if(d[t]==inf) 65 return false; 66 return true; 67 } 68 void solve() 69 { 70 for(int i=pre[t];i!=-1;i=pre[edge[i].s]) 71 { 72 edge[i].f-=1; 73 edge[i^1].f+=1; 74 mincost+=edge[i].w; 75 } 76 } 77 int work() 78 { 79 int m, n1; 80 scanf("%d %d",&m, &n1); 81 n = m + n1; 82 s=0;t=n+1; 83 tot=0; 84 mincost=0; 85 memset(head,-1,sizeof(head)); 86 for (int i=1; i<=m; ++i) 87 for (int j=m+1; j<=m+n1; ++j) 88 { 89 int c; 90 scanf("%d", &c); 91 addedge(i,j,c,1); 92 } 93 94 for (int i=1; i<=m; ++i) 95 { 96 addedge(s,i,0,n); 97 addedge(s,i,-BOUND,1); 98 } 99 100 for (int j=m+1; j<=m+n1; ++j) 101 { 102 addedge(j,t,0,m); 103 addedge(j,t,-BOUND,1); 104 } 105 106 int cnt = 0, mn = m > n1 ? m : n1; 107 while(spfa() && cnt<mn ) 108 { 109 solve(); 110 cnt++; 111 } 112 113 int minsum = mincost; 114 while(spfa() && cnt<m+n ) 115 { 116 solve(); 117 minsum = minsum < mincost ? minsum : mincost; 118 cnt++; 119 } 120 121 return minsum + BOUND * (m + n1); 122 } 123 124 int main() 125 { 126 int cases; 127 scanf("%d", &cases); 128 for (int i=1; i<=cases; i++) 129 { 130 printf("Case %d: %d\n", i, work() ); 131 } 132 133 return 0; 134 }
我的WA代码:
1 #include <stdio.h> 2 int D[100][100]; 3 int D2[100][100]; 4 bool id[100]; 5 int reserve[100]; 6 int column; 7 8 void hasVal(int *p, int i, int n , int (*M)[100], int &maxVal) 9 { 10 if(i == n) 11 { 12 int sum = 0; 13 for(int j = 0; j < n; ++j) 14 { 15 sum += M[j][p[j]]; 16 } 17 if(sum < maxVal) 18 maxVal = sum; 19 } 20 for(int begin = i; begin < n; ++begin) 21 { 22 int t = p[i]; 23 p[i] = p[begin]; 24 p[begin] = t; 25 hasVal(p, i+1, n, M , maxVal); 26 t = p[i]; 27 p[i] = p[begin]; 28 p[begin] = t; 29 } 30 } 31 32 int SumOfSqureMatrix(int (*M)[100] , int n) 33 { 34 int *p = new int[n]; 35 int maxVal = 0x7fffffff; 36 for(int i = 0; i < n; ++i) 37 p[i] = i; 38 hasVal(p, 0, n, M, maxVal); 39 delete[] p; 40 p = NULL; 41 return maxVal; 42 } 43 44 int getSum(bool id[], int m , int n) 45 { 46 column = 0; 47 for(int j = 0,k2 = 0; j < n; ++j){ 48 if(id[j]){ 49 for(int i = 0; i < m; ++i) 50 { 51 D2[i][column] = D[i][j]; 52 } 53 ++column; 54 }else{ 55 int s = 0; 56 int tem = D[s++][j]; 57 while(s < m){ 58 if(D[s][j] < tem) 59 tem = D[s][j]; 60 s++; 61 } 62 reserve[k2++] = tem; 63 } 64 } 65 int sum = SumOfSqureMatrix(D2, m); 66 for(int t = 0;t < n-m; t++) 67 sum += reserve[t]; 68 return sum; 69 } 70 71 void permutation(bool id[], int start, int m, int n, int &sum) 72 { 73 if(start == m) 74 { 75 int temp = getSum(id, m, n); 76 for(int i = 0; i < n; ++i) 77 printf("%d ", id[i]); 78 printf("\n"); 79 if(temp < sum) sum = temp; 80 return; 81 } 82 for(int begin = start; begin < n; ++begin) 83 { 84 bool tem = id[begin]; 85 id[begin] = id[start]; 86 id[start] = tem; 87 permutation(id, start + 1, m, n, sum); 88 89 tem = id[begin]; 90 id[begin] = id[start]; 91 id[start] = tem; 92 93 } 94 } 95 96 int SumOfMatrix(int m, int n) 97 { 98 if(m > n) return 0; 99 int sum = 0x7fffffff; 100 for(int i = 0; i < n; ++i) 101 id[i] = 0; 102 for(int i = 0; i < m; ++i) 103 id[i] = 1; 104 if(m == 1 || n == 1) return getSum(id, m, n); 105 permutation(id, 0, m, n, sum); 106 return sum; 107 } 108 109 void reverse(int m, int n) 110 { 111 int **tem = new int*[n]; 112 for(int i = 0; i < n; ++i) 113 { 114 tem[i] = new int[m]; 115 for(int j = 0; j < m; ++j) 116 tem[i][j] = D[j][i]; 117 } 118 for(int i = 0; i < n; ++i) 119 for(int j = 0; j < m; ++j) 120 D[i][j] = tem[i][j]; 121 for(int i = 0; i < n; ++i) 122 { 123 delete[] tem[i]; 124 tem[i] = NULL; 125 } 126 delete[] tem; 127 tem = NULL; 128 } 129 130 int main() 131 { 132 int T; 133 scanf("%d", &T); 134 for(int k = 0; k < T; ++k) 135 { 136 int m , n; 137 scanf("%d%d", &m, &n); 138 for(int i = 0; i < m; ++i) 139 for(int j = 0; j < n; ++j) 140 scanf("%d", &D[i][j]); 141 if(m == n) 142 { 143 printf("Case %d: %d\n", k+1, SumOfSqureMatrix(D, n)); 144 continue; 145 } 146 else if(m > n) 147 { 148 reverse(m, n); 149 m ^= n; 150 n ^= m; 151 m ^= n; 152 } 153 printf("Case %d: %d\n", k+1, SumOfMatrix(m, n)); 154 } 155 return 0; 156 }
错误样例:
1
3 3
9 1 1
1 9 9
1 9 9
错误原因:以为最多选取 max(M,N) 个数。
适合情况:最多选取 max(M,N) 个数。