动态规划系列(三)——高楼扔鸡蛋
一、解析题目
题目是这样:你面前有一栋从 1 到
N
共N
层的楼,然后给你K
个鸡蛋(K
至少为 1)。现在确定这栋楼存在楼层0 <= F <= N
,在这层楼将鸡蛋扔下去,鸡蛋恰好没摔碎(高于F
的楼层都会碎,低于F
的楼层都不会碎)。现在问你,最坏情况下,你至少要扔几次鸡蛋,才能确定这个楼层F
呢?
实际上,如果不限制鸡蛋个数的话,二分思路显然可以得到最少尝试的次数,但问题是,现在给你了鸡蛋个数的限制K
,直接使用二分思路就不行了。
比如说只给你 1 个鸡蛋,7 层楼,你敢用二分吗?你直接去第 4 层扔一下,如果鸡蛋没碎还好,但如果碎了你就没有鸡蛋继续测试了,无法确定鸡蛋恰好摔不碎的楼层F
了。这种情况下只能用线性扫描的方法,算法返回结果应该是 7。
有的读者也许会有这种想法:二分查找排除楼层的速度无疑是最快的,那干脆先用二分查找,等到只剩 1 个鸡蛋的时候再执行线性扫描,这样得到的结果是不是就是最少的扔鸡蛋次数呢?
很遗憾,并不是,比如说把楼层变高一些,100 层,给你 2 个鸡蛋,你在 50 层扔一下,碎了,那就只能线性扫描 1~49 层了,最坏情况下要扔 50 次。
二、思路分析
对动态规划问题,直接套我们以前多次强调的框架即可:这个问题有什么「状态」,有什么「选择」,然后穷举。
「状态」很明显,就是当前拥有的鸡蛋数K
和需要测试的楼层数N
。随着测试的进行,鸡蛋个数可能减少,楼层的搜索范围会减小,这就是状态的变化。
「选择」其实就是去选择哪层楼扔鸡蛋。回顾刚才的线性扫描和二分思路,二分查找每次选择到楼层区间的中间去扔鸡蛋,而线性扫描选择一层层向上测试。不同的选择会造成状态的转移。
现在明确了「状态」和「选择」,动态规划的基本思路就形成了:肯定是个二维的dp
数组或者带有两个状态参数的dp
函数来表示状态转移;外加一个 for 循环来遍历所有选择,择最优的选择更新结果 :
# 当前状态为 (K 个鸡蛋,N 层楼)
# 返回这个状态下的最优结果
def dp(K, N):
int res
for 1 <= i <= N:
res = min(res, 这次在第 i 层楼扔鸡蛋)
return res
这段伪码还没有展示递归和状态转移,不过大致的算法框架已经完成了。
我们在第i
层楼扔了鸡蛋之后,可能出现两种情况:鸡蛋碎了,鸡蛋没碎。注意,这时候状态转移就来了:
如果鸡蛋碎了,那么鸡蛋的个数K
应该减一,搜索的楼层区间应该从[1..N]
变为[1..i-1]
共i-1
层楼;
如果鸡蛋没碎,那么鸡蛋的个数K
不变,搜索的楼层区间应该从 [1..N]
变为[i+1..N]
共N-i
层楼。
因为我们要求的是最坏情况下扔鸡蛋的次数,所以鸡蛋在第i
层楼碎没碎,取决于那种情况的结果更大:
def dp(K, N):
for 1 <= i <= N:
# 最坏情况下的最少扔鸡蛋次数
res = min(res,
max(
dp(K - 1, i - 1), # 碎
dp(K, N - i) # 没碎
) + 1 # 在第 i 楼扔了一次
)
return res
递归的 base case 很容易理解:当楼层数N
等于 0 时,显然不需要扔鸡蛋;当鸡蛋数K
为 1 时,显然只能线性扫描所有楼层:
至此,其实这道题就解决了!只要添加一个备忘录消除重叠子问题即可:
def superEggDrop(K: int, N: int):
memo = dict()
def dp(K, N) -> int:
# base case
if K == 1: return N
if N == 0: return 0
# 避免重复计算
if (K, N) in memo:
return memo[(K, N)]
res = float('INF')
# 穷举所有可能的选择
for i in range(1, N + 1):
res = min(res,
max(
dp(K, N - i), # 没碎
dp(K - 1, i - 1) # 碎
) + 1
)
# 记入备忘录
memo[(K, N)] = res
return res
return dp(K, N)
这个算法的时间复杂度是多少呢?动态规划算法的时间复杂度就是子问题个数 × 函数本身的复杂度。
函数本身的复杂度就是忽略递归部分的复杂度,这里dp
函数中有一个 for 循环,所以函数本身的复杂度是 O(N)。
子问题个数也就是不同状态组合的总数,显然是两个状态的乘积,也就是 O(KN)。
所以算法的总时间复杂度是 O(K*N^2), 空间复杂度为子问题个数,即 O(KN)。
三、疑难解答
有读者可能不理解代码中为什么用一个 for 循环遍历楼层[1..N]
,也许会把这个逻辑和之前探讨的线性扫描混为一谈。其实不是的,这只是在做一次「选择」。
比方说你有 2 个鸡蛋,面对 10 层楼,你得拿一个鸡蛋去某一层楼扔对吧?那选择去哪一层楼扔呢?不知道,那就把这 10 层楼全试一遍。至于鸡蛋碎没碎,下次怎么选择不用你操心,有正确的状态转移,递归会算出每个选择的代价,我们取最优的那个就是最优解。
四、优化
一、二分搜索优化
核心是因为状态转移方程的单调性
这个 for 循环就是下面这个状态转移方程的具体代码实现:
二分查找的运用很广泛,形如下面这种形式的 for 循环代码:
for (int i = 0; i < n; i++) {
if (isOK(i))
return i;
}
都很有可能可以运用二分查找来优化线性搜索的复杂度,回顾这两个dp
函数的曲线,我们要找的最低点其实就是这种情况:
for (int i = 1; i <= N; i++) {
if (dp(K - 1, i - 1) == dp(K, N - i))
return dp(K, N - i);
}
这就是相当于求 Valley(山谷)值嘛,可以用二分查找来快速寻找这个点的
def superEggDrop(self, K: int, N: int) -> int:
memo = dict()
def dp(K, N):
if K == 1: return N
if N == 0: return 0
if (K, N) in memo:
return memo[(K, N)]
# for 1 <= i <= N:
# res = min(res,
# max(
# dp(K - 1, i - 1),
# dp(K, N - i)
# ) + 1
# )
res = float('INF')
# 用二分搜索代替线性搜索
lo, hi = 1, N
while lo <= hi:
mid = (lo + hi) // 2
broken = dp(K - 1, mid - 1) # 碎
not_broken = dp(K, N - mid) # 没碎
# res = min(max(碎,没碎) + 1)
if broken > not_broken:
hi = mid - 1
res = min(res, broken + 1)
else:
lo = mid + 1
res = min(res, not_broken + 1)
memo[(K, N)] = res
return res
return dp(K, N)
这个算法的时间复杂度是多少呢?动态规划算法的时间复杂度就是子问题个数 × 函数本身的复杂度。
函数本身的复杂度就是忽略递归部分的复杂度,这里dp
函数中用了一个二分搜索,所以函数本身的复杂度是 O(logN)。
子问题个数也就是不同状态组合的总数,显然是两个状态的乘积,也就是 O(KN)。
所以算法的总时间复杂度是 O(KNlogN), 空间复杂度 O(KN)。效率上比之前的算法 O(KN^2) 要高效不少。
二、重写状态转移
再回顾一下我们之前定义的dp
数组含义:
def dp(k, n) -> int
# 当前状态为 k 个鸡蛋,面对 n 层楼
# 返回这个状态下最少的扔鸡蛋次数
用 dp 数组表示的话也是一样的:
dp[k][n] = m
# 当前状态为 k 个鸡蛋,面对 n 层楼
# 这个状态下最少的扔鸡蛋次数为 m
这种思路下,肯定要穷举所有可能的扔法的,用二分搜索优化也只是做了「剪枝」,减小了搜索空间,但本质思路没有变,只不过是更聪明的穷举。
现在,我们稍微修改dp
数组的定义,确定当前的鸡蛋个数和最多允许的扔鸡蛋次数,就知道能够确定F
的最高楼层数。
dp[k][m] = n
# 当前有 k 个鸡蛋,可以尝试扔 m 次鸡蛋
# 这个状态下,最坏情况下最多能确切测试一栋 n 层的楼
# 比如说 dp[1][7] = 7 表示:
# 现在有 1 个鸡蛋,允许你扔 7 次;
# 这个状态下最多给你 7 层楼,
# 使得你可以确定楼层 F 使得鸡蛋恰好摔不碎
# (一层一层线性探查嘛)
我们最终要求的其实是扔鸡蛋次数m
,但是这时候m
在状态之中而不是dp
数组的结果,可以这样处理:
int superEggDrop(int K, int N) {
int m = 0;
while (dp[K][m] < N) {
m++;
// 状态转移...
}
return m;
}
这种dp
定义基于下面两个事实:
1、无论你在哪层楼扔鸡蛋,鸡蛋只可能摔碎或者没摔碎,碎了的话就测楼下,没碎的话就测楼上。
2、无论你上楼还是下楼,总的楼层数 = 楼上的楼层数 + 楼下的楼层数 + 1(当前这层楼)。
根据这个特点,可以写出下面的状态转移方程:
dp[k][m] = dp[k][m-1] + dp[k-1][m-1] + 1
dp[k][m - 1]
就是楼上的楼层数,因为鸡蛋个数k
不变,也就是鸡蛋没碎,扔鸡蛋次数m
减一;
dp[k - 1][m - 1]
就是楼下的楼层数,因为鸡蛋个数k
减一,也就是鸡蛋碎了,同时扔鸡蛋次数m
减一。
至此,整个思路就完成了,只要把状态转移方程填进框架即可:
int superEggDrop(int K, int N) {
// m 最多不会超过 N 次(线性扫描)
int[][] dp = new int[K + 1][N + 1];
// base case:
// dp[0][..] = 0
// dp[..][0] = 0
// Java 默认初始化数组都为 0
int m = 0;
while (dp[K][m] < N) {
m++;
for (int k = 1; k <= K; k++)
dp[k][m] = dp[k][m - 1] + dp[k - 1][m - 1] + 1;
}
return m;
}
因为我们要求的不是dp
数组里的值,而是某个符合条件的索引m
,所以用while
循环来找到这个m
而已。