矩阵连乘问题

矩阵连乘是经典的DP问题。n个矩阵连乘会因为乘的次序问题导致效率差异。

这个问题中比较难想的是计算m[][]时的次序,必须保证计算m[i][j]查找m[i][k]和m[k+1][j]时这两个已经计算出来了。

这里是根据王晓东的算法书里的想法写的,他用的是Java,必须提醒的就是在用C++时一个蛋疼的数组越界问题。

int *p = new int[5];

此时p[5] = 10; 就是越界的,但C++为了效率上的考虑华丽地无视了这个越界,编译器并没有义务指出它。

更bt的是cout << p[5] << endl; 是正确的,能输出10。但这程序指不定什么时候就诡异得挂了。

所以n个矩阵连乘中开的数组最好干脆用vector<int>写。如果用new的话,要注意其实我们是要访问p[n],m[n][n],s[n][n]的,所以要开n+1的数组。看到网上的有些版本没有注意这点。

 

程序的测试数据为:

A1 30x35;  A2 35x15;  A3 15x5;  A4 5x10;  A5 10x20;  A6 20x25

 1 #include <iostream>
 2 
 3 using namespace std;
 4 
 5 void matrixChain(int *p, int **m, int **s, int n) //p记录矩阵的行列,s记录从哪断开
 6 {
 7     for (int i=1; i<=n; ++i)
 8     {
 9         m[i][i] = 0;
10     }
11 
12     for (int r=2; r<=n; ++r)
13     {
14         for (int i=1; i<=n-r+1; ++i)
15         {
16             int j = r + i - 1;
17             m[i][j] = m[i][i] + m[i+1][j] + p[i-1]*p[i]*p[j];
18             s[i][j] = i;
19 
20             for (int k=i+1; k<j; ++k)
21             {
22                 int temp = m[i][k] + m[k+1][j] + p[i-1]*p[k]*p[j];
23                 if (temp < m[i][j])
24                 {
25                     m[i][j] = temp;
26                     s[i][j] = k;
27                 }
28             }
29         }
30     }
31 }
32 
33 void traceback(int **s, int i, int j) //输出A[i:j]的最优计算次序
34 {
35     if(i == j)
36     {
37         cout << "A" <<i;
38     }
39     else if (i+1 == j)
40     {
41         cout << "(A"<< i << "A" << j <<")";
42     }
43     else
44     {
45         cout << "(";
46         traceback(s, i, s[i][j]);
47         traceback(s, s[i][j]+1, j);
48         cout << ")";
49     }
50 }
51 
52 int main()
53 {
54 
55     int n = 6;
56     int *p = new int[n+1];
57 
58 
59     p[0]=30;
60     p[1]=35;
61     p[2]=15;
62     p[3]=5;
63     p[4]=10;
64     p[5]=20;
65     p[6]=25;
66 
67 
68     int **m, **s;
69     m = new int *[n+1];
70     s = new int *[n+1];
71     for(int i=0; i<=n; ++i)
72     {
73         m[i] = new int[n+1];
74         s[i] = new int[n+1];
75     }
76 
77 
78     matrixChain(p, m, s, n);
79     traceback(s, 1, n);
80 
81     for(int i=0; i<=n; ++i)
82     {
83         delete []m[i];
84         m[i] = 0;
85         delete []s[i];
86         s[i] = 0;
87     }
88     delete []m;
89     m = 0;
90     delete []s;
91     s = 0;
92     delete []p;
93     p = 0;
94 
95     return 0;
96 }

输出:

 

posted @ 2012-04-24 11:20  漂木  阅读(447)  评论(0编辑  收藏  举报