动态规划算法问题详解:高楼扔鸡蛋问题的思路分析和优化实现

490 阅读7分钟

这是我参与2022首次更文挑战的第5天,活动详情查看:2022首次更文挑战

基本概念

  • 高楼扔鸡蛋问题:
    • 有一栋1NN层楼,有k个鸡蛋,其中k至少为1
    • 已经去顶这栋楼存在楼层0<=F<=N,0 <= F <= N, 在这层将鸡蛋扔下楼,鸡蛋恰好没碎,即高于F的楼层鸡蛋都会碎,低于F的楼层鸡蛋都不会碎
    • 在最坏的情况下,至少要扔几次鸡蛋,能够确定这样的楼层F
      • 最坏情况: 鸡蛋破碎的情况一定是发生在搜索区间穷尽时
        • 不考虑鸡蛋个数
        • 比如楼层一共有7层,扔鸡蛋最原始的方式就是从第1层开始扔鸡蛋尝试,如果扔到第7层鸡蛋也没有碎,这样的情况就是最坏情况
        • 鸡蛋不会在例如楼层的第一层就会破碎,这不符合最坏的情况
      • 至少: 能确定结果层数尝试的最小次数
        • 不考虑鸡蛋个数
        • 采用二分法就可以确定7层至少要尝试3次,分别为第1+72=4\frac{1 + 7}{2} = 4次,第5+72=6\frac{5 + 7}{2} = 6次,第7次
        • 至少扔的次数为logN\log N向上取整
  • 如果不限制鸡蛋个数,二分法可以得到最少尝试次数.现在给定鸡蛋个数限制为k, 不能直接使用二分法思路

思路分析

  • 动态规划问题框架:
    • 有哪些状态
    • 有哪些选择
    • 穷举计算
  • 高楼扔鸡蛋问题:
    • 状态: 当前拥有的鸡蛋数k和需要测试的楼层数N
      • 随着测试的进行,鸡蛋数会减少,楼层数也会减少,这就是状态的变化
    • 选择: 选择去哪一层扔鸡蛋
      • 有线性扫描和二分思路
      • 不同的选择会造成状态转移
  • 明确了状态和选择之后,就可以形成基本的动态规划的思路:
    • 是一个二维的DP数组或者带有两个状态参数的DP函数来表示状态转移
    • 使用一个for循环来遍历所有选择,使用最优的选择来更新状态
# 当前状态为k个鸡蛋,有N层楼
# 返回这个状态下的最优结果
def dp(K, N):
	int res
	for 1 <= i <= N:
		res = min(res, 在第i层扔鸡蛋)
	return res
  • 状态转移: 在第i层扔鸡蛋,可能会出现两种情况引起状态转移
    • 鸡蛋碎了: 鸡蛋的个数k1, 搜索的楼层区间从 [1,...,N] 变为 [1,...,i-1]. 楼层数为i-1
    • 鸡蛋未碎: 鸡蛋的个数k不变,搜索的楼层区间从 [1,...,N] 变为 [i+1,...,N]. 楼层数为N-i
def dp(k, n):
	int res
	for 1 <= i <= N:
	 res = min(res, min(
	 	# 鸡蛋碎了
	 	dp(K - 1, i - 1),
	 	# 鸡蛋未碎
	 	dp(K, N - i)
	 	# 在第i层扔了1次
	 	) + 1)
	 return res
  • basecase:
    • 当鸡蛋个数K1时,需要线性扫描所有楼层
    • 当楼层层数N0时,不需要扔鸡蛋
def dp(K, N):
	int res
	if K == 1:
		return N
	if N==0:
		return 0
	for 1 <= i <= N:
		res = min(res, 
			# 因为要求的是最坏情况,所以求两者之间的最大值
			max(
				dp(K - 1, i - 1),
				dp(K, N - i)
			) + 1)
	return res
  • 添加备忘录消除重叠子问题:
def supperEggDrop(K : int, N : int):
	memo = dict();
	def dp(K, N) -> int:
		# base case
		if K == 1:
			return N
		if N == 0:
			return 0
		# 如果值已经在备忘录中存在,则直接返回.避免重复计算
		if (K, N) in memo:
			return memo[(K, N)]
		res = float("INF")
		# 穷举所有可能的选择
		for i in range(1, N + 1):
			res = min(res,
				max(
					# 鸡蛋碎了
					dp(K - 1, i - 1),
					# 鸡蛋未碎
					dp(K, N - i)
				) + 1
				)
		# 将结果保存到备忘录
		memo[(K, N)] = res
		return res
	return dp(K, N)
  • 算法时间复杂度: 动态规划算法时间复杂度 = 子问题个数 * 函数本身的复杂度
    • 子问题个数: 不同状态的组合总数,也就是两个状态的乘积.即O(KN)O(KN)
    • 函数本身的复杂度: 在函数中有一个for循环,函数本身的时间复杂度为O(N)O(N)
    • 算法时间复杂度: 两者的乘积O(KN)O(N)=O(KN2)O(KN) * O(N) = O(KN^2)
  • 算法空间复杂度为O(KN)O(KN)

思路优化

  • 动态规划问题要使用好备忘录或者DP Table
  • 高楼扔鸡蛋问题动态规划算法原始思路:
    • 穷举尝试所有楼层1 <= i <= N扔鸡蛋,每次选择尝试次数最少的一层
    • 每次扔鸡蛋的有两种可能:鸡蛋碎了,鸡蛋未碎
    • 如果鸡蛋碎了,楼层应该在i层的下面. 如果鸡蛋未碎,楼层应该在i层的上面
    • 鸡蛋碎未碎,取决于哪种情况下尝试的次数更多,因为题目要求的是最坏的情况下的结果
  • 使用一个for循环遍历楼层1 - N,是在做1次选择:
    • 比如有2个鸡蛋,面对10层楼,在选择去哪一层楼时,就可以将这10层楼全部试一遍
    • 至于下一次怎么选择,会有正确的状态转移,递归会计算出每个选择的代价,选取代价最小的那个就是最优解
def dp(K, N):
	res = float("INF")
	for 1 <= i <= N:
		# 最坏情况下的最少扔鸡蛋的次数
		res = min(res,
				max(
					# 鸡蛋碎了
					dp(K - 1, i - 1),
					# 鸡蛋未碎
					dp(K, N - i)
				) + 1
			)
  • 代码优化:
    • 修改代码中的for循环为二分搜索,可以将算法时间复杂度降低为O(KNlogN)O(KNlogN)
      • 这里的二分解法与二分思路扔鸡蛋没有关系
      • 使用二分解法是因为状态转移的函数具有单调性,可以使用二分搜索法快速找到最值
    • 改进动态规划的解法可以进一步时间复杂度降为O(KN)O(KN)
    • 再使用数学方法,可以将时间复杂度达到最优O(KlogN)O(KlogN), 空间复杂度达到最优O(1)O(1)

二分查找优化

  • 二分查找优化的核心在于状态转移方程的单调性
  • 首先根据dp(K,N)dp(K, N) 数组的定义:有K个鸡蛋面对N层楼,最少需要扔几次.当K固定时,这个函数一定是单调递增的
    • 这样函数dp(K1,i1)dp(K-1,i-1)和函数dp(K,Ni)dp(K,N-i):其中i是从1N单调递增的,如果固定KN, 将这两个函数看作是i的函数,那么函数dp(K1,i1)dp(K-1, i - 1)是单调递增的,函数dp(K,Ni)dp(K,N-i)是单调递减的
    • 此时,若要求两者中的较大值,再求这些最大值中的最小值,就是求这两个函数的交点.可以使用二分法快速查找
  • 二分查找可以用于优化以下形式的for循环代码:
for (int i = 0; i < n; i++) {
	if (isOk(i)) {
		return i;
	}
}
  • 这里要求的是两个函数的交点:
if (int i = 1; i <= N; i++) {
	if (dp(K - 1, i - 1) == dp(K, N - i)) {
		return dp(K - 1, i - 1);
	}
}
def superEggDrop(K : int, N : int) -> int:
	memo = dict()
	def dp(K, N):
		if K == 1:
			return N
		if N == 0:
			return 0
		if (K, N) in memo:
			return memo[(K, N)]
		# 用二分搜索代替线性搜索
		low, high = 1, N
		while low < high:
			mid = low + (high - low) / 2
			broken = dp(K - 1, mid - 1)
			not_broken = dp(K, N - mid)
			# res = res + min(max(broken, not_broken) + 1)
			if broken > not_broken:
				high = mid - 1
				res = min(res, broken + 1)
			else:
				low = mid + 1
				res = min(res, not_broken + 1)
		memo[(K, N)] = res
		return res
	return dp(K, N)
  • 二分查找优化的算法时间复杂度: 函数本身的时间复杂度 * 子问题个数
    • 函数本身时间复杂度: 忽略递归部分的算法时间复杂度,使用了二分查找优化的dp函数的复杂度是O(logN)O(logN)
    • 子问题个数: 不同状态组合的总数,即两个状态的乘积,也就是O(KN)
    • 算法时间复杂度: O(KNlogN)=O(KN)O(logN)O(KNlogN) = O(KN) * O(logN)
  • 空间复杂度为O(KN)

重新定义状态转移

  • 动态规划状态转移中不同的状态定义会有不同的解法
  • 原始dp数组定义:
dpf dp(K, N) -> int
# 当前状态为K个鸡蛋,面对n层楼
# 返回这个状态下最少扔鸡蛋的次数

dp[K, N] = m
# 使用dp数组表示的含义
# 当前状态为K个鸡蛋,面对n层楼
# 返回这个状态下最少扔鸡蛋的次数为m
  • 按照这个dp数组定义:
    • 就是确定当前鸡蛋的个数和面对的层数,就知道最小扔鸡蛋的次数.最终答案就是dp(K, N) 的结果
    • 在这种定义的前提下,肯定要穷举所有可能的扔法,二分优化也只是做了剪枝,减小的搜索区间,但是本质还是穷举的思想
  • 重新定义状态转移: dp数组的定义可以为,确定当前鸡蛋的个数K以及最多允许扔鸡蛋的次数m, 就能确定N的最高楼层数
dp[K][m] = N
# 当前有k个鸡蛋,可以尝试扔m次鸡蛋
# 在这样的条件下,最坏的情况能够测试N层楼
  • 最终要求的是扔鸡蛋的次数m, 但是mdp数组状态之中而不是dp数组的结果,可以对数组这样处理:
int superEggDrop(int K, int N) {
	int m = 0;
	/*
	 * while循环的结束条件 : dp[K][m] = N
	 * 	- 给定的K个鸡蛋,测试m次,最坏的情况下最多能测试N层楼
	 */
	while (dp[K][m] < N) {
		m++
		// 状态转移
	} 
	return m;
}
  • 状态转移方程:
    • dp数组的定义有以下两种特点:
      • 无论在哪一层扔鸡蛋,鸡蛋只可能摔碎或者没摔碎,碎了就测楼下,没碎就测楼上
      • 无论是上楼还是下楼,总楼层 = 楼上的层数 + 楼下的层数 + 1
    • 根据dp数组特点,得出状态转移方程: dp[k][m]=dp[k][m1]+dp[k1][m1]+1dp[k][m]=dp[k][m-1]+dp[k-1][m-1]+1
      • dp[K][m-1]: 楼上的层数. 因为鸡蛋没碎,所以K值不变,同时m值减1
      • dp[K-1][m-1]: 楼下的层数. 因为鸡蛋碎了,所以K值减1, 同时m值减1
      • 因为定义的m的含义是一个允许的次数上界,所以这里是减1而不是加1
int superEggDrop(int K, int N) {
	// 定义dp数组.dp数组m值不会超过楼层数N
	int[][] dp = new int[K + 1][N + 1];
	// base case
	dp[0][N + 1] = 0;
	dp[K + 1][0] = 0;
	int m = 0;
	/*
	 * while中的代码等价于:
	 * 	for (int m = 1; dp[K][m] < N; m++) {
	 * 		for (int k = 1; k <= K; k++) {
	 * 			dp[K][m] = dp[K][m - 1] + dp[K - 1][m - 1] + 1
	 * 		}
	 * 	}
	 */
	while (dp[K][m]) {
		m++;
		for (int k = 1; k <= K; k++) {
			dp[K][m] = dp[K][m - 1] + dp[K - 1][m - 1] + 1;
		}
	}
	return m;
}
  • 算法时间复杂度即为两个循环的乘积O(KN)O(KN)
  • dp[K][m] 状态转移只和左边和左上的两个状态有关,所以很容易可以优化为一维数组
  • 可以继续使用二分搜索代替线性扫描获取m的值:
    • 注意函数dp(K, m) 是随着m单调递增的,因为鸡蛋不变时,允许测试的次数m越多,可以测试的楼层数N就越高
    • 这样就可以根据二分搜索算法快速逼近dp(K,m)==Ndp(K,m)==N这个终止条件,时间复杂度下降为O(KlogN)O(KlogN)
// 将线性搜索修改为二分搜索
// for (int m = 1; dp[K][m] < N; m++)
int low = 1, high = N;
while (low < high) {
	int mid = low + (high - low) / 2;
	if (mid < N) {
		low = mid + 1;
	} else {
		high = mid - 1;
	}
	for (int k = 1; k <= K; k++) {
		dp[K][mid] = dp[K][m - 1] + dp[K - 1][m - 1] + 1
	}
}

总结

  • 第一种二分搜索优化是利用了dp函数的单调性,使用二分技巧快速搜索最优状态
  • 第二种优化巧妙地修改了状态转移方程,简化了求解流程,但是思考逻辑较为复杂,再通过一些数据方法和二分搜索可以进一步优化第二种搜索解法