类型题Ⅴ:回溯法
类型题Ⅴ:回溯法
文章目录
- 题型一:排列、组合、子集相关问题-
回溯法也称作暴搜(暴力搜索),本质是穷举状态空间所有可能。
回溯法解题框架:
result = [] def backtrack(路径, 已选择列表): if 满足结束条件: result.add(路径) return for 选择 in 选择列表: # 做选择 将该选择从选择列表移除 路径.add(选择) # 进入下一层选择 backtrack(路径, 选择列表) # 撤销选择 路径.remove(选择) 将该选择再加入选择列表
回溯算法的时间复杂度很高,在遍历的时候,如果能够提前知道这一条分支不可能搜索到满意的结果,就可以提前结束,这一步操作称为剪枝。有时候剪枝需要做一些预处理工作(例如排序)。预处理工作虽然也消耗时间,但剪枝能够节约的时间更多。剪枝是一种技巧,通常需要根据不同问题场景采用不同的剪枝策略,需要在做题的过程中不断总结。
题型一:排列、组合、子集相关问题
排列问题元素有序,[2, 2, 3] 和 [2, 3, 2] 是两个不同的结果,元素不可以重复选择,需要用状态位辅助。
组合问题元素无序,[2, 2, 3] 和 [2, 3, 2] 是两个相同的结果,需要用起始点辅助剪枝。1. 全排列
全排列
46. 全排列:给定一个没有重复数字的序列,返回其所有可能的全排列。
同一个数字不能重复使用,使用一个标记位判断当前数字是否已经使用过即可。
时间复杂度:
O ( n × n ! ) O(n \times n!) O(n×n!)<br> 空间复杂度: O ( n × n ! ) O(n \times n!) O(n×n!),共 n ! n! n! 个全排列,每个全排列占空间 O ( n ) O(n) O(n)
# 解法一 class Solution: def permute(self, nums: List[int]) -> List[List[int]]: def dfs(index, cur): if index == n: res.append(cur[:]) # 注意这里要用cur[:]对cur进行copy return for i in range(n): if nums[i] != '#': tmp = nums[i] cur.append(tmp) # 做选择 nums[i] = '#' # 将该选择从选择列表移除 dfs(index+1, cur) # 进入下一层选择 nums[i] = tmp # 将该选择再加入选择列表 cur.pop() # 撤销选择 res = [] n = len(nums) dfs(0, []) return res
注意第一和下面第二种解法细微的差别,在于下一层的选择列表如何设置,对应的搜索结束条件也因此不同。(不需要标记位,感觉这种代码更简洁一些)
# 解法二 class Solution: def permute(self, nums: List[int]) -> List[List[int]]: def backtrack(nums, tmp): if not nums: res.append(tmp) return for i in range(len(nums)): backtrack(nums[:i] + nums[i+1:], tmp + [nums[i]]) res = [] backtrack(nums, []) return res
还可以将数组划分成左右两个部分,左边的表示已经填过的数,右边表示待填的数,在递归搜索的时候只要动态维护这个数组即可。(不需要标记位)
# 解法三 class Solution: def permute(self, nums): def backtrack(first = 0): if first == n: res.append(nums[:]) for i in range(first, n): nums[first], nums[i] = nums[i], nums[first] # 做交换 backtrack(first + 1) nums[first], nums[i] = nums[i], nums[first] # 撤销交换 n = len(nums) res = [] backtrack() return res
全排列 II
47. 全排列 II:给定一个可包含重复数字的序列,返回所有不重复的全排列。
对于数组中存在重复数字的情况,就需要剪枝。就本题来说,对于“同一层”的选择,如果出现一摸一样的路径,就是应该被剪掉的。可以先对数组进行排序,在同层选择时,如果和前面是重复的,就没有必要再选了。
# 标记位 + 剪枝 class Solution: def permuteUnique(self, nums: List[int]) -> List[List[int]]: def dfs(index, arr): if index == n: res.append(arr[:]) return for i in range(n): if i > 0 and nums[i] == nums[i-1]: # 剪枝 continue cur = nums[i] if cur != '#': nums[i] = '#' arr.append(cur) dfs(index+1, arr) arr.pop() nums[i] = cur n = len(nums) nums.sort() # 进行剪枝的前处理 res = [] dfs(0, []) return res
字符串的排列
剑指 Offer 38. 字符串的排列:输入一个字符串,打印出该字符串中字符的所有排列。
实际就是将数字换成字符的全排列问题。
class Solution: def permutation(self, s: str) -> List[str]: def dfs(index, arr): if index == n: res.append(''.join(arr[:])) return for i in range(n): if i > 0 and strs[i] == strs[i-1]: continue cur = strs[i] if cur != '#': arr.append(cur) strs[i] = '#' dfs(index+1, arr) strs[i] = cur arr.pop() res = [] strs = list(s) strs.sort() n = len(strs) dfs(0, []) return res
2. 组合
组合总和
39. 组合总和:给定一个无重复元素的数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。candidates 中的数字可以无限制重复选取。
# 解法一:利用元组set去重(不推荐) class Solution: def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(num, cur): if num == target: res.append(cur[:]) return for i in range(n): if num+candidates[i] <= target: cur.append(candidates[i]) dfs(num+candidates[i], cur) cur.pop() n = len(candidates) res = [] dfs(0, []) # 利用元组去重 sets = set() for item in res: item.sort() sets.add(tuple(item)) return list(sets)
没有剪枝之前运行时间会很长,下面进行剪枝处理,具体操作是:每做一个选择,后续的可选择列表,只能是当前选择位置右侧(包含当前位置),所以可以用一个 index 来指示当前的选择位置,在递归中进行传参。
# 解法二:剪枝 class Solution: def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(num, cur, index): if num == target: res.append(cur[:]) return for i in range(index, n): # 用index进行辅助剪枝 if num+candidates[i] <= target: cur.append(candidates[i]) dfs(num+candidates[i], cur, i) # 设置下一轮的起点 cur.pop() n = len(candidates) res = [] dfs(0, [], 0) return res
对剪枝进行优化,可以先对原数组进行排序,如果当前的数加上之后大于 target,那么再向后也不可能找到了(因为是递增的),直接可以 break 掉。
# 解法三:剪枝优化 class Solution: def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(num, cur, index): if num == target: res.append(cur[:]) return for i in range(index, n): if num+candidates[i] <= target: cur.append(candidates[i]) dfs(num+candidates[i], cur, i) cur.pop() else: break # 如果当前不满足,后面的也不可能满足,直接break掉 n = len(candidates) candidates.sort() # 先进行排序 res = [] dfs(0, [], 0) return res
组合总和 II
40. 组合总和 II:给定一个数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。candidates 中的每个数字在每个组合中只能使用一次。
由于只能使用一次,所以用状态位进行标记,另外,组合问题会有重复解,不推荐的去重方法是单独去重。
# 解法一:单独去重(不推荐) class Solution: def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(num, cur): if num == target: res.append(cur[:]) return for i in range(n): tmp = candidates[i] if tmp != '#' and (num+tmp) <= target: cur.append(tmp) candidates[i] = '#' dfs(num+tmp, cur) cur.pop() candidates[i] = tmp res = [] n = len(candidates) dfs(0, []) # 去重 sets = set() for item in res: item.sort() sets.add(tuple(item)) res = list(sets) return res
进行剪枝,思路和前面一样,用起始点 start 辅助剪枝,每次只能从当前位置右侧的待选列表中选。另外,发现数组中是有重复数字的,所以还要用排序辅助去重。
# 解法二:剪枝 class Solution: def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(num, cur, start): if num == target: res.append(cur[:]) return for i in range(start, n): if i > 0 and candidates[i] == candidates[i-1]: # 排序去重 continue tmp = candidates[i] if tmp != '#' and (num+tmp) <= target: cur.append(tmp) candidates[i] = '#' dfs(num+tmp, cur, i+1) # 起始点去重 cur.pop() candidates[i] = tmp else: break # 排序后的剪枝优化 res = [] n = len(candidates) candidates.sort() dfs(0, [], 0) return res
组合总和 III
216. 组合总和 III:找出所有相加之和为 n 的 k 个数的组合。组合中只允许含有 1 - 9 的正整数,并且每种组合中不存在重复的数字。解集不能包含重复的组合。
class Solution: def combinationSum3(self, k: int, n: int) -> List[List[int]]: def dfs(index, cnt, start, arr): if index == k and cnt == n: res.append(arr[:]) return for i in range(start, 9): cur = nums[i] if cur != '#' and (cur+cnt) <= n: arr.append(cur) nums[i] = '#' dfs(index+1, cur+cnt, i+1, arr) nums[i] = cur arr.pop() nums = [i for i in range(1, 10)] res = [] dfs(0, 0, 0, []) return res
组合
77. 组合:给定两个整数 n 和 k,返回 1 … n 中所有可能的 k 个数的组合。
这道题就是限制了个数的组合问题,递归结束条件就是当前已经有 k 个数,为了剪枝(去重),每次寻找时都从 start 位置开始向后寻找即可。
class Solution: def combine(self, n: int, k: int) -> List[List[int]]: def dfs(cnt, start, cur): if cnt == k: res.append(cur[:]) return for i in range(start, n+1): cur.append(i) dfs(cnt+1, i+1, cur) # 剪枝 cur.pop() res = [] dfs(0, 1, []) return res
进一步优化,根据 n 和 k,可以推出搜索上界,即在这个界限之后再进行枚举是没有意义的,因为数不够用,比如 n = 5,k = 3 时,搜索上界就是 4,此时以 4 和 5 开头的路径都是可以直接剪掉的。
class Solution: def combine(self, n: int, k: int) -> List[List[int]]: def dfs(cnt, start, cur): if cnt == k: res.append(cur[:]) return for i in range(start, n+1): if n-i >= k-cnt-1: # 如果当前位置后面的元素个数小于剩余需要选取的个数,则直接剪枝 cur.append(i) dfs(cnt+1, i+1, cur) cur.pop() res = [] dfs(0, 1, []) return res
子集
78. 子集:给定一组不含重复元素的整数数组 nums,返回该数组所有可能的子集(幂集)。解集不能包含重复的子集。
子集包含的元素可以从 0 个到 n 个,可以用递归分别对这 n+1 种情况进行处理,但是画出决策树发现这样会有很多的重复计算,因为在枚举 2 个数的子集时,其实是基于 1 个数的子集的,如果像下面这样循环,实际每次都要从第一个数开始,不推荐。
class Solution: def subsets(self, nums: List[int]) -> List[List[int]]: def dfs(index, arr, cnt, start): if index == cnt: res.append(arr[:]) return for i in range(start, n): cur = nums[i] if cur != '#': arr.append(cur) nums[i] = '#' dfs(index+1, arr, cnt, i+1) nums[i] = cur arr.pop() res = [] n = len(nums) for i in range(n+1): dfs(0, [], i, 0) # 枚举n+1种情况,具有很多重复计算 return res
进一步优化,大子集总是从小子集扩展而来的,只要每添加一个元素就向结果集中 append 即可,注意元素不能重复选取,且元素间无序,所以用 start 标记开始位置。
class Solution: def subsets(self, nums: List[int]) -> List[List[int]]: def dfs(start, arr): if start == n: return for i in range(start, n): arr.append(nums[i]) res.append(arr[:]) # 每更新一次就向结果集中添加 dfs(i+1, arr) arr.pop() res = [[]] n = len(nums) dfs(0, []) return res
子集 II
90. 子集 II:给定一个可能包含重复元素的整数数组 nums,返回该数组所有可能的子集(幂集)。解集不能包含重复的子集。
对于可能包含重复元素的情况,一般会利用排序剪枝,其余思路和【8】一样。
class Solution: def subsetsWithDup(self, nums: List[int]) -> List[List[int]]: def dfs(start, arr): if start == n: return for i in range(start, n): # 利用排序剪枝,针对原数组可能包含的重复元素 if i > start and nums[i] == nums[i-1]: continue arr.append(nums[i]) res.append(arr[:]) dfs(i+1, arr) # 利用start剪枝,针对重复子集 arr.pop() res = [[]] n = len(nums) nums.sort() # 利用排序剪枝 dfs(0, []) return res
第k个排列
60. 第k个排列:给定 n 和 k,按大小顺序列出 [1,2,3,…,n] 的所有排列,返回第 k 个排列。
按照朴素思路,按顺序找,找到第 k 个时,直接 break 掉返回即可。这里用一个 cnt 变量保存当前已有多少个结果,当 cnt 等于 k 时,就说明找到了第 k 个,但是这种简单粗暴的方法会计算超时。
# 解法一:暴力搜索 class Solution: def getPermutation(self, n: int, k: int) -> str: def dfs(index, arr, cnt): if index == n: cnt += 1 res.append(arr[:]) return for i in range(n): cur = nums[i] if cur != '#': arr.append(str(cur)) nums[i] = '#' dfs(index+1, arr, cnt) if cnt == k: break arr.pop() nums[i] = cur nums = [i for i in range(1, n+1)] res = [] dfs(0, [], 0) return ''.join(res[k-1])
寻找剪枝条件:实际根据 n 就能知道有没有必要向下找,因为对决策树某一层的节点来说,每个节点对应的最底层叶子节点个数等于剩余数字个数的阶乘。
所求排列 一定在叶子结点处得到,进入每一个分支,可以根据已经选定的数的个数,进而计算还未选定的数的个数,然后计算阶乘,就知道这一个分支的 叶子结点 的个数:
如果 k 大于这一个分支将要产生的叶子结点数,直接跳过这个分支,这个操作叫「剪枝」;- 如果 k 小于等于这一个分支将要产生的叶子结点数,那说明所求的全排列一定在这一个分支将要产生的叶子结点里,需要递归求解。 ``` # 解法二:剪枝 class Solution: def getPermutation(self, n: int, k: int) -> str:
# 创建n!数组 cnt = [1 for i in range(n)] for i in range(1, n): cnt[i] = cnt[i - 1] * i def dfs(index, arr, rem): if index == n: return for i in range(n): cur = nums[i] # 搜索当前位置的数 if cur != '#': if cnt[n - index - 1] < rem: rem -= cnt[n - index - 1] continue arr.append(str(cur)) nums[i] = '#' dfs(index + 1, arr, rem) res = ''.join(arr) return res nums = [i for i in range(1, n+1)] res = dfs(0, [], k) return res
### 复原IP地址 [93. 复原IP地址](https://leetcode-cn.com/problems/restore-ip-addresses/):给定一个只包含数字的字符串,复原它并返回所有可能的 IP 地址格式。 定义辅助函数 check 校验当前枚举数字的合法性,要求数字范围在 0~255 之间,且 0 只可以单独使用,不能作前导 0。 在递归函数里,用 index 指示当前深度,cnt 指示当前是 IP 地址的第几个数字,用 index 和 cnt 共同作为找到目标答案的返回条件。由于是十进制,一个数字最长就是三位,所以在 for 循环中枚举三种长度即可。注意剪枝条件,如果从当前 index 开始的数字不合法,那么向后继续查找也不会合法,可以直接 break 掉。另外,在程序最开始先判断,如果字符串长度大于 12,那么一定找不到答案。
class Solution: def restoreIpAddresses(self, s: str):
def check(str): if len(str) > 1 and str[0] == '0': return False return 0 <= int(str) <= 255 def dfs(index, trace, rem, cnt): if index == n: if cnt == 4: res.append(trace) return if trace: trace += '.' for i in range(3): if len(rem) > i and check(rem[:i+1]): dfs(index + i + 1, trace + rem[:i + 1], rem[i + 1:], cnt + 1) else: break res = [] n = len(s) if n > 12: return [] dfs(0, '', s, 0) return res
# 题型二:字符串中的回溯问题 [22. 括号生成](https://leetcode-cn.com/problems/generate-parentheses/submissions/):数字 n 代表生成括号的对数,请你设计一个函数,用于能够生成所有可能的并且 有效的 括号组合。 利用加减法来做,判断的各种条件都基于左、右括号的数量。注意左括号只要有剩余个数就能往里加,而右括号除此之外还受到左括号的限制,也就是已经添加进来的右括号个数不能大于已经添加进来的左括号个数,同理也就是剩余的左括号个数不能大于剩余的右括号个数。 这里就没有进行显式的 append 和 pop 过程了。
class Solution: def generateParenthesis(self, n: int) -> List[str]:
def dfs(arr, left, right): if left == 0 and right == 0: res.append(arr) return if left > right: # 剩余的左括号个数不能大于剩余的右括号个数,否则是无效串 return if left > 0: dfs(arr + '(', left-1, right) if right > 0: dfs(arr + ')', left, right-1) res = [] dfs('', n, n) return res
[17. 电话号码的字母组合](https://leetcode-cn.com/problems/letter-combinations-of-a-phone-number/):给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。 常规做法,用 index 指示当前数字个数,作为递归结束条件。
解法一
class Solution: def letterCombinations(self, digits: str) -> List[str]: if not digits: return list()
phoneMap = {<!-- --> "2": "abc", "3": "def", "4": "ghi", "5": "jkl", "6": "mno", "7": "pqrs", "8": "tuv", "9": "wxyz", } def backtrack(index: int): if index == len(digits): combinations.append("".join(combination)) else: digit = digits[index] for letter in phoneMap[digit]: combination.append(letter) backtrack(index + 1) combination.pop() combination = list() combinations = list() backtrack(0) return combinations
另一种方式,用剩余 num 作为递归结束条件:
解法二
class Solution: def letterCombinations(self, digits: str) -> List[str]: if not digits: return []
phoneMap = {<!-- --> "2": "abc", "3": "def", "4": "ghi", "5": "jkl", "6": "mno", "7": "pqrs", "8": "tuv", "9": "wxyz", } def dfs(trace, nums): if not nums: res.append(trace) return for item in phoneMap[nums[0]]: dfs(trace+item, nums[1:]) res = [] dfs('', digits) return res
# 题型三:二维矩阵中的搜索问题 [79. 单词搜索](https://leetcode-cn.com/problems/word-search/):给定一个二维网格和一个单词,找出该单词是否存在于网格中,字母**不允许重复使用**。 首先找到与 word 首字母匹配的起始点,递归进入查找。在深搜函数里,传入当前位置坐标 i,j,另外传入当前 word 的计数下标 index。搜索的边界条件:1)当 index 等于 word 长度时,表示搜索成功,返回 True(注意 index 从 0 开始,当 index = word 长度时,说明 word 中所有字母均被搜索过了);2)当搜索位置不合法(越界)或者 搜索位置字母与当前目标字母不匹配 或者 搜索位置上的字母已经被使用过时,搜索失败返回 False。 这里使用标记 # 直接对数组元素进行修改,并在撤销选择时对其进行恢复。
class Solution: def exist(self, board: List[List[str]], word: str) -> bool:
def find(index, i, j): if index == num: # 搜索成功 return True if i < 0 or i >= r or j < 0 or j >= c: # 搜索位置不合法 return False if board[i][j] == word[index]: # 搜索位置匹配 tmp = board[i][j] board[i][j] = '#' res = find(index+1, i-1, j) or find(index+1, i+1, j) or find(index+1, i, j-1) or find(index+1, i, j+1) board[i][j] = tmp return res return False r, c, num = len(board), len(board[0]), len(word) res = False for i in range(r): for j in range(c): if board[i][j] == word[0]: res = res or find(0, i, j) return res
另外,偏移量数组在二维平面搜索问题中比较常用,可以按照下面方法进行设置:
class Solution: def exist(self, board: List[List[str]], word: str) -> bool:
directions = [(-1,0), (1,0), (0,-1), (0,1)] # 分别对应上下左右四个方向 def find(index, i, j): if index == num: return True if i < 0 or i >= r or j < 0 or j >= c: return False if board[i][j] == word[index]: tmp = board[i][j] board[i][j] = '#' res = False for di, dj in directions: newi, newj = i + di, j + dj res = res or find(index+1, newi, newj) board[i][j] = tmp return res return False r, c = len(board), len(board[0]) num = len(word) res = False for i in range(r): for j in range(c): if board[i][j] == word[0]: res = res or find(0, i, j) return res
```