dp方法论——由矩阵相乘问题学习dp解题思路
导语
刷过一些算法题,就会十分珍惜“方法论”这种东西。Leetcode上只有题目、讨论和答案,没有方法论。往往答案看起来十分切中要害,但是从看题目到得到思路的那一段,就是绕不过去。楼主有段时间曾把这个过程归结于智商和灵感的结合,直到有天为了搞懂Leetcode上一位老兄的题型总结,花两天时间学习了回溯法,突然有种惊为天人的感觉——原来真正掌握一个算法是应该触类旁通的,而不是将题中一个细节换掉就又成了新题……
掌握方法论绝对是一种很爽的感觉。看起来好像很花费时间,其实是一种“因为慢,所以快”的方法。以前可能你学习一个dp题目要大半天;当你花了半个周时间,学会了dp的套路,你会发现,有些medium的dp题甚至不需要半个小时就能做完,而且从头到尾不需提示,全靠自己!
方法论
那么,怎么从一个看起来毫无头绪的问题出发,找到解题的思路并用dp将问题解出来呢?本文以矩阵相乘问题为例,给出dp问题的一般解题思路。
当然,按照思路解题的前提是你已经知道这道题要用dp去解,如何确定一个问题可以用dp去解,则是下一篇要讨论的话题。
下面就是动态规划的一般解题思路:
- 分析最优解的特征。
- 递归地定义最优解的值。
- 计算最优解的值。
- 根据计算好的信息构造最优解。
看起来非常抽象是吧?在这里不需要完全理解。等你看完全文再回来,保你会有不一样的感受。
矩阵相乘问题
问题
这是一个看起来可能有点抽象的数学问题,但请你耐心往下看。当你看完解法时,你会惊异于动态规划的魔力。
题目:给出一个由n个矩阵组成的矩阵链<A1,A2,...,An>,矩阵Ai的秩为pi-1×pi。将A1A2...An这个乘积全括号化,使得计算这个乘积所需要的的标量乘法最少。
全括号化是以一种递归的形式定义的:
一个全括号化的乘积只有两种可能:一是一个单个矩阵;二是两个全括号化的乘积的乘积。
天啦也太绕了,举个例子吧。对于矩阵链<A1,A2,A3,A4>的乘积,共有五种全括号化的方法:
(A1(A2(A3A4))),
(A1((A2A3)A4)),
((A1A2)(A3A4)),
(((A1A2)A3)A4),
((A1(A2A3))A4)
我们知道矩阵乘法是满足结合律的,所以以上五个式子的乘积相等,但是它们的运算时间是否相等呢?
矩阵乘法的运算时间
我们知道,矩阵乘法的定义是:
两个互相兼容的矩阵A,B可以相乘。互相兼容是指A的列数与B的行数相等。假如A是一个p×q的矩阵,而B是一个q×r的矩阵,则乘积C是一个p×r的矩阵且有
cij = ∑ aik·bkj, k = 1,...,q.
由于要对C中的每一个元素进行计算(共q·r个元素),而每次运算要做q次乘法,所以总的运算时间为pqr。
来看看让乘积中的不同因子结合对运算时间有什么影响。假设我们有 <A1,A2,A3>这个矩阵链,三个矩阵的秩分别为10×100, 100×5和5×50。则
- ((A1A2)A3)的运算时间为10×100×5+10×5×50=7500;
- (A1(A2A3))的运算时间为100×5×50+10×100×50=75000。
按照不同的顺序做矩阵乘法,所需要的乘法次数竟相差10倍。
初步分析
按照惯例,我们来感受一下穷举的算法复杂度。
假设有一个长度为n的矩阵链,我们通过遍历所有的全括号化的可能性来解题。设全括号化的可能性数目为P(n)。当n为1时,矩阵链只有一个矩阵,符合全括号化的定义;当n>=2时,全括号化后为两个矩阵的乘积,即((...)(...))的形式。用递归的思路去分析,则中间两个括号的分界位置有n-1种可能,如下面竖线所示
A1|A2|A3|...|An
当分界线将矩阵链分为长度为k和n-k的两个子矩阵链时,全括号化可能性为P(k)P(n-k)。我们对所有的k值求和,就得出给整个矩阵链全括号化的数目:
P(n) = ∑ P(k)P(n-k), k=1...n-1 (n>=2)
这是一个卡塔兰数(Catalan Number),它的增长速率为Ω(4n/n3/2),它的渐进值为Ω(2n)。
(对渐进值还不太熟,如果有小伙伴明白“增长速率”和“渐进值”之间的关系,欢迎指教。)
总的来说,如果对这个题目使用穷举法,算法复杂度是指数的。后面我们分析了dp的算法复杂度,再来比较。
用dp方法论解题
算法的学习永远没有“手把手”这一说。如果你在认真学习这篇文章,希望你能做到比你看到的小节思路提前一点。比如,在看第一步前,先对这个题目有一点大致思路,明白让自己迷茫的点在哪里;看第x步前,对第x步的内容在心中有一个猜测。这样做比起完全放弃思考,只是跟着文章的思路走,收获会大很多。
第一步:分析最优解的特征
这一步的精髓是分析最优子解如何构成最优解。
在上一节中已经提到,对于n>=2的情况,全括号化后为((chain_1)(chain_2))的形式。这样,问题自然而然地分成了两个子问题:求前后两个子括号中的最优解。
假设对于某种特定的分割(即chain_1和chain_2之间的分界线位置固定),chain_1的秩为m×p,其内部的标量乘法数目为x;chain_2的秩为p×n,其内部的标量乘法数目为y。则整个矩阵链的乘法次数为x+y+mpn。由于m,p,n是固定的,我们需要让x和y为最小值从而使整个矩阵链的乘法次数最小。即,对于某种特定的分割,两个子括号中的最优解构成整个问题的最优解的一个选项。
总结来说,我们将矩阵乘积简略地看成两个子矩阵链的乘积,这两个子矩阵链的分界有n-1种可能。对每一种可能,问题被分割成两个子问题,即求左右两个子矩阵链的最优解。如果遍历这n-1种可能并选出最好的一个,那就是整个问题的最优解。
第二步:递归地定义最优解的值
第二步非常关键,是我们将前后思路打通的一步。
第一步中提出了一个比较简单的思路,即把矩阵链分割成左右两个子矩阵链。既然有了这个初步思路,我们就来涂鸦一番,看看这个思路是否可行。
对于递归性的问题,一个很好的方法是画递归树,这样会使得问题看起来比较具象,而且也会暴露一些算法上的问题,比如重叠子树等。画递归树的时候,最好举一个实际的例子。这里我们假设有一个长度为4的矩阵链<A1,A2,A3,A4>,简单地画一下它的子问题分割:
上图中的数字表示子矩阵链的长度,根为4,即初始矩阵链;它可以分为1+3,2+2,3+1三种情况,这三种情况又可以各自细分。
这里暴露了一个问题,请看图中的两个涂色的子树。两个子树的节点数字是一样的。但是左边这个子树的根节点3代表的是A2A3A4这个乘积;而右边这个代表的是A1A2A3这个乘积。由于A1,A2,A3,A4四个矩阵的秩是未知的,它们很可能不相同,则A1A2A3和A2A3A4的最优解也很有可能不同。换言之,它们并不是同一个子问题,它们的子子树也并不相同。
这个问题意味着我们对子问题的定义不够严谨——子问题不能只用长度这个变量来确定。也就是说,如果在bottom-up的dp中用一个数组记录子问题的值,那么这个数组应该是一个二维数组。子问题不仅应该由子矩阵链的长度确定,还要加上起始index这样的信息。
为了更通用一些,我们不用起始index+长度,而选用起始index+结束index的定义方法,这是二维dp的惯用套路,在许多字符串和数组有关的问题中都有用到。
设用一个二位矩阵dp[][]存取子问题的解。定义dp[i][j](1<=i<=j<=n)的值为Ai...Aj的最小乘法次数。则按照以上的思路,可以把Ai...Aj再递归细分为子问题Ai...Ak和Ak+1...Aj(i<=k<j),则Ai...Aj的最优解值为两个子问题最优解的和+两个子矩阵链相乘的乘法次数。即有
i==j时,dp[i][j] = 0;
i <j时,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p为各个矩阵的秩,见题目一节)
到此为止,最关键的一步顺利完成啦(楼主写得好累,击掌╭(○`∀´○)╯╰(○'◡'○)╮)。在这一步中,我们递归地定义了子问题最优解的值,完成了算法最核心的设计部分。在后面两步中,我们只要把上面这两个式子翻译成代码,再注意一些实现细节就可以了。
第三步:计算最优解的值
细节一
从第二步顺理成章,我们会在一个二维数组里记录子问题的解。但是按照什么顺序去填这个二维数组是个问题。
还是举例子,在<A1,A2,A3,A4>这个矩阵链中,我们会有一个5×5的二维数组,随便挑选dp[1][4]这个元素举例。根据第二步中的状态转移方程,有
dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}
省略号表示我们此处不需关注pi-1pkpj这一项,只需要看这个格子对其它格子的依赖是什么样子。
由上图可以看出,要计算某一个元素(粉色边框),我们需要其左边和下面的元素(同样深度的蓝色表示一组数据)。
所以,我们的遍历方向是从下到上,从左到右。
细节二
细心的读者可能注意到还有一个问题,就是我们一直在求“最优解的值”,也就是“最小的乘法次数”,可是题目中要求的是“最优解”,也就是“加括号的方式”。
这两者并不矛盾,专注于求解前者可以让我们先思考相对简单的问题,通常在求解前者的过程中,我们也找出了后者,只是没有将它记录下来。
在此题中,我们可以选择用一个同样的二维矩阵s[][]来记录后者,其中s[i][j]中记录Ai...Aj的分割分界线k。
代码
1 int matrixChain(int[] p){ 2 int n = p.length - 1; //number of matrices 3 int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n] 4 int[][] s = new int[n + 1][n + 1]; //for storing of k 5 for(int[] row : dp) 6 Arrays.fill(row, Integer.MAX_VALUE); 7 8 for(int i = 1; i <= n; i++) 9 dp[i][i] = 0; //dp[i][j] = 0 when i == j 10 11 for(int i = n; i >= 1; i--) 12 for(int j = i; j <= n; j++){ 13 if(i == j){ 14 dp[i][j] = 0; 15 }else{ 16 for(int k = i; k < j; k++){ 17 int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j]; 18 if(count < dp[i][j]){ 19 dp[i][j] = count; //record optimal solution value 20 s[i][j] = k; //record splitting point k 21 } 22 } 23 } 24 } 25 return dp[1][n]; 26 }
运行一个例子:
即输入的数组p为{30,35,15,5,10,20,25}。
如果在return之前打印出dp[][]和s[][]的值,结果为:
从左图可看出最优解为dp[1][6] = 15,125,即最少可以进行一万五千多次乘法。右图记录了对于每一个[i,j]决定的子矩阵链如何进行括号分割。
顺便分享一个ArrayPrinter的util,可以直接用,能打印出上图那样的二维int数组。
1 public class ArrayPrinter { 2 public static void print(int[] arr){ 3 printReplacing(false, arr, 0,""); 4 } 5 6 public static void print(int[][] matrix){ 7 printReplacing(false, matrix, 0,""); 8 } 9 10 public static void printReplacing(int[] arr, int before, String after){ 11 printReplacing(true, arr, before, after); 12 } 13 14 public static void printReplacing(int[][] matrix, int before, String after){ 15 printReplacing(true, matrix, before, after); 16 } 17 18 /*--------------------------private utils-------------------------------*/ 19 20 private static void printReplacing(boolean replace, int[] arr, int before, String after){ 21 int maxLen = maxLength(arr); 22 if(replace){ 23 for(int i : arr) 24 print(((i==before)?after:number(i)), maxLen); 25 }else{ 26 for(int i : arr) 27 print(number(i), maxLen); 28 } 29 print("\n", maxLen); 30 } 31 32 public static void printReplacing(boolean replace, int[][] matrix, int before, String after){ 33 int maxLen = maxLength(matrix); 34 if(replace){ 35 for(int[] row : matrix){ 36 for(int i : row) 37 print(((i==before)?after:number(i)), maxLen); 38 print("\n", maxLen); 39 } 40 }else{ 41 for(int[] row : matrix){ 42 for(int i : row) 43 print(number(i), maxLen); 44 print("\n", maxLen); 45 } 46 } 47 } 48 49 private static int maxLength(int[] arr){ 50 int maxLen = 0; 51 for(int aint : arr) 52 maxLen = Math.max(Integer.toString(aint).length(), maxLen); 53 return maxLen; 54 } 55 56 private static int maxLength(int[][] matrix){ 57 int maxLen = 0; 58 for(int row[] : matrix) 59 maxLen = Math.max(maxLength(row), maxLen); 60 return maxLen; 61 } 62 63 //actual printing 64 private static void print(String s, int length){ 65 System.out.print(String.format("%1$"+(length+1)+"s", s)); 66 } 67 68 //formatting of number 69 private static String number(int i){ 70 return NumberFormat.getNumberInstance(Locale.US).format(i); 71 } 72 }
使用方法:
1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/"); 2 ArrayPrinter.print(s);
第四步:根据计算好的信息构造最优解
还差一步就大功告成。这一步我们要拿着上一步计算出的矩阵s把最终的全括号矩阵乘积打印出来。递归打印即可。
1 private void printParenthesis(int[][] s, int i, int j) { 2 if(i == j) 3 print("A"+i); 4 else{ 5 print("("); 6 printParenthesis(s, i, s[i][j]); 7 printParenthesis(s, s[i][j]+1, j); 8 print(")"); 9 } 10 }
打印结果:
复杂度
前面说过,穷举法的复杂度大概是O(2n)。在以上的dp算法中,主算法需要填满一个(n+1)×(n+1)的二维数组的上半部分,每填一个元素需要一个长度为j-i的循环,可通过这个思路对j-i进行求和(i=0...n, j=i...n),也可以通过大概估算得到时间复杂度为O(n3),远好于穷举法。
空间复杂度主要由二维数组决定,为O(n2)。
总结
本文主要介绍了解一个dp问题的思路。
dp问题一般有两个显著特点,这一点下一篇会详细讲述:
- 问题的最优解由子问题的最优解构成
- 子问题互相重叠
也再复习一下解题的四个步骤,看你现在有没有更深刻的理解:
- 分析最优解的特征。 (分析最优子解如何构成最优解)
- 递归地定义最优解的值。 (画递归树,定义子问题,写状态转移方程)
- 计算最优解的值。 (写代码求出最优解,如果有要求的话,记录额外信息,为第4步作准备)
- 根据计算好的信息构造最优解。 (从第3步记录的信息中构建最优解,在本题中就是括号的写法)