看这个就够了

导航

从递归到记忆化搜索到动态规划

动态规划的状态转移方程一般不容易找出来,并且两个变量的动态规划也不容易直接写出,我以leetcode No.300 最长递增子序列为例,总结一下是如何一步步从最开始的递归做法到记忆化搜索再到动态规划的。
首先题目如下:
给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。
子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 [0,3,1,6,2,2,7] 的子序列。

示例 1:
输入:nums = [10,9,2,5,3,7,101,18]
输出:4
解释:最长递增子序列是 [2,3,7,101],因此长度为 4 。

示例 2:
输入:nums = [0,1,0,3,2,3]
输出:4

示例 3:
输入:nums = [7,7,7,7,7,7,7]
输出:1

递归

如果一眼动态规划的题目没有思路,不妨先从递归开始一步步变成动态规划

以[1,3,2,5,7,4]为例

暴力解法(递归):

  • 从1开始的话,要选比它大的数字,那么遍历后续数组下一个数字可以选3或者5或者7或者4

  • 选[1,3]后,继续遍历后续数组,直到最后一位。

  • 要注意的是每次选择数字后,递增序列长度就+1,然后要比较是否是最长序列。

  • 每一次遍历的过程就是一次递归。

  • 由于序列不一定从1开始,从头对每个数组都做一次递归。

代码

class Solution:

  def findLengthOfLCIS(self, nums) -> int:

     maxnum = 0

     for i in range(len(nums)):

       maxnum = max(maxnum,self.digui(nums,i))

     return maxnum

  

  def digui(self,nums,i):

     maxnum = 1



     if i == len(nums)-1:

       return 1

     for j in range(i+1,len(nums)):

       if nums[j] > nums[i]:

         maxnum = max(maxnum,self.digui(nums,j)+1)

     return maxnum

记忆化搜索

然后会发现,我进行了很多重复的计算,比如在最开始选1的时候,我计算了2的递增序列长度,在选[1,3]后我又要计算了2的递增长度

于是,我们可以创建一个memo数组,在第一次计算完2的递增序列长度后,将其保存下来,下次直接查找就可以使用。

这就是记忆化搜索。

代码

class Solution:

  memo = []

  def findLengthOfLCIS(self, nums) -> int:

     self.memo = [0]*len(nums)

     maxnum = 0

     for i in range(len(nums)):

       maxnum = max(maxnum,self.digui(nums,i))

     return maxnum

  

  def digui(self,nums,i):

     maxnum = 1

     if i in self.memo:

       return self.memo[i]

     if i == len(nums)-1:

       return 1

     for j in range(i+1,len(nums)):

       if nums[j] > nums[i]:

         maxnum = max(maxnum,self.digui(nums,j)+1)

     self.memo[i] = maxnum

     return maxnum

动态规划

我们注意到这个时候,已经几乎写出来了状态转移方程

if nums[j] > nums[i]:

maxnum = max(maxnum,self.digui(nums,j)+1)

memo数组就相当于dp数组,其中存放的是每个数的最大递增序列长度

我们从后往前遍历一遍nums就可以了

代码

最终动态规划代码:

class Solution:

  def lengthOfLIS(self, nums) -> int:

     if nums == []:

       return 0

     dp = [1]*len(nums)

     

     for i in range(len(nums)-1,-1,-1):

       #print(i)

       for j in range(i+1,len(nums)):

         if nums[i] < nums[j]:

           dp[i] = max(dp[i],dp[j] + 1)

     #print(dp)

     return max(dp)
	

posted on 2022-01-13 15:43  看这个就够了  阅读(181)  评论(0编辑  收藏  举报