矩阵连乘求解优化
前言
从旭东的博客 看到一篇博文:矩阵连乘最优结合 动态规划求解,挺有意思的,这里做个转载【略改动】。
问题
矩阵乘法是这样的,比如\[ A_{ab} B_{bc} = C_{ac} \]
两个矩阵,一个a行,一个c列,行列乘法次数为a*c。一行乘以一列得到C中的一个元素,乘法次数为b,故矩阵乘法AB需要的乘法次数是a*c*b。
我们把b称为接口,那么矩阵连乘的次数是乘积的尺寸乘以中间的接口。中间的接口是矩阵高度,如果尽快能把长得高的矩阵通过乘法消化掉,这些大接口发生作用的机会就少,最终乘法次数就少了。
采用一维数组存储各矩阵高度。每次遍历找到最大值,和左边的矩阵相乘即可,直到最后只剩下一个矩阵。
输入参数是数组arr:存储各矩阵高度,最后一个元素为最后一个矩阵的列数,这个数组包含矩阵连乘表达式的所有信息。
loopTimes 是循环次数,循环次数为矩阵个数减1。
arrMaxId函数用于获取数组最大值索引,跳过第一个矩阵,因为第一个矩阵左边没有其他矩阵作为乘数。
由于最后要输出计算式,我们为每个矩阵设置一个名称,这个名称随着乘法的进行发生变化。最终会留下第一个矩阵,其名称就是最终运算式。
代码如下
int matrixMulTimes(vector<int> &arr) { int maxId; int mulTimes = 0; int pre = 0, next=0; string str=""; vector<string> matrixName(arr.size() - 1); string mulLeftStr, mulRightStr; int loops = 0; int loopTimes = arr.size() - 2; while (loops++ < loopTimes) { maxId = arrMaxId(arr, 1, arr.size() - 2); pre = maxId - 1; next = maxId + 1; while (arr[pre--] == -1); while (arr[next++] == -1); mulTimes += arr[++pre] * arr[maxId] * arr[--next]; arr[maxId] = -1; mulLeftStr = matrixName[pre] == "" ? string(1, 'A' + pre) : matrixName[pre]; mulRightStr = matrixName[maxId] == "" ? string(1, 'A' + maxId) : matrixName[maxId]; matrixName[pre] = "(" + mulLeftStr +"*" + mulRightStr + ")"; } cout << matrixName[0] << endl; return mulTimes; }
函数arrMaxId代码如下
int arrMaxId(vector<int> &arr, int begin, int end) { if (arr.size() == 0 || arr.size()<end) { return -1; } int maxId = begin; for (int i = begin+1; i <= end; ++i) { if (arr[i] > arr[maxId]) { maxId = i; } } return maxId; }
上面只能是一个近似最优解,因为每次消去最高的矩阵,可能参与乘法的另一个矩阵也比较高,导致其存活更久更多地参与到运算中去,最后得不偿失。
如果非要得到最优解,可以运算并存储所有子式的运算量,自底向上,直到算出整个乘法算式。这个叫做动态规划,不过复杂度比较高。如果设置1000个矩阵相乘,动态规划可能算十几分钟都不一定有结果,但第一个算法几秒钟就能给出答案。
首先确定跨度,再确定起点,构成一个二级嵌套循环,通过起点的平移确定子式。子式确定后,再在内部嵌套一级循环,遍历子式的可能的分割点,并保存子式的最少计算次数及对应分割点。最后我们得到最大的子式,也就是连乘本身的最少运算次数及分割方案。
代码如下:
//根据记录的分割点,生成最后的矩阵相乘表达式 string make_result(vector<vector<int> > &points, int t1, int t2) { if (t1 == t2) return string(1, 'A' + t1 - 1); int split = points[t1][t2]; return "(" + make_result(points, t1, split) + "*" + make_result(points, split + 1, t2) + ")"; } int calculate_M(vector<int> &arr) { int matrixNum = arr.size() - 1; vector<vector<int> > num(matrixNum + 1, vector<int>(matrixNum + 1)); vector<vector<int> > points(matrixNum + 1, vector<int>(matrixNum + 1)); int span; int start; int end; int spiltPoint; int mulTimes; int rows, columns, interfaces; for (span = 1; span < arr.size() - 1; span++) { for (start = 1; start + span < arr.size(); start++) { end = start + span; num[start][end] = INT_MAX; for (spiltPoint = start; spiltPoint < end; spiltPoint++) { rows = arr[start - 1]; columns = arr[end]; interfaces = arr[spiltPoint]; mulTimes = num[start][spiltPoint] + num[spiltPoint + 1][end] + rows * interfaces * columns; if (mulTimes < num[start][end]) { points[start][end] = spiltPoint; num[start][end] = mulTimes; } } } } cout << make_result(points, 1, matrixNum) << "\t最少乘法次数为:" << num[1][matrixNum] << endl; return 0; }
代码中用到的一些知识
C++提供模版类string,其中一个构造方法可将字符转化为字符串。如 string(1, 'A'+1),第一个参数是源字符延拓次数,这个构造函数将‘B’转化为"B"。