数位dp通用模板
记忆化搜索:dfs
dfs框架为:1、结束条件 2、扩展状态(枚举可能情况) 3、返回结果
数位dp通常使用dfs来实现
二进制代表状态:
0代表不在集合中,1代表在集合中。
判断第k个元素是否在集合中: mask & (1 << k) 是否为1
将第k个元素加入到集合中: mask | (1 << k)
完整的模板大概如下:
class Solution: def numDupDigitsAtMostN(self, n: int) -> int: s = str(n) @cache def f(i: int, mask: int, is_limit: bool, is_num: bool) -> int: if i == len(s): return int(is_num) res = 0 if not is_num: res = f(i + 1, mask, False, False) low = 0 if is_num else 1 up = int(s[i]) if is_limit else 9 for d in range(low, up + 1): if (mask >> d & 1) == 0: res += f(i + 1, mask | (1 << d), is_limit and d == up, True) return res return n - f(0, 0, True, False)
参数i:当前搜索到了第i个位置
mask:当前集合中的状态
is_limit :当前位置填入数字的时候是否收到约束,比如最大数为123,若我们第一位填了1,第二位最大只能填入2,此时is_limit为true 若我们第一位填了0,则第二位可以填1 - 9,此时is_limit为false。
is_num:当前位置之前是否已经填了元素,若为false,则可以跳过。如n = 1125, 若is_num连续三次为false则可以枚举个位,is_num连续两次为false可以枚举十位
例题:
class Solution: def countDigitOne(self, n: int) -> int: s = str(n) @cache def f(i: int, cnt: int, is_limit: bool): if i == len(s): return cnt res = 0 up = int(s[i]) if is_limit else 9 for d in range(up + 1): res += f(i + 1, cnt + (d == 1), is_limit and d == up) return res return f(0, 0, True)
class Solution: def atMostNGivenDigitSet(self, digits: List[str], n: int) -> int: s = str(n) @cache def f(i: int, is_limit: bool, is_num: bool) -> int: if i == len(s): return int(is_num) res = 0 if not is_num: res = f(i + 1, False, False) up = s[i] if is_limit else '9' for d in digits: if d > up: break res += f(i + 1, is_limit and d == up, True) return res return f(0, True, False)
class Solution: def numDupDigitsAtMostN(self, n: int) -> int: s = str(n) @cache def f(i: int, mask: int, is_limit: bool, is_num: bool) -> int: if i == len(s): return int(is_num) res = 0 if not is_num: res = f(i + 1, mask, False, False) low = 0 if is_num else 1 up = int(s[i]) if is_limit else 9 for d in range(low, up + 1): if (mask >> d & 1) == 0: res += f(i + 1, mask | (1 << d), is_limit and d == up, True) return res return n - f(0, 0, True, False)
class Solution: def findIntegers(self, n: int) -> int: s = str(bin(n))[2:] @cache def f(i: int, pre: bool, is_limit: bool) -> int: if i == len(s): return 1 up = int(s[i]) if is_limit else 1 res = f(i + 1, False, is_limit and up == 0) if not pre and up == 1: res += f(i + 1, True, is_limit) return res return f(0, False, True)
本篇学习自bi站灵神(灵茶山艾府)的周赛讲解,小伙伴们可以去看看大佬讲解。