问题分析
题目描述
-
有n道菜,每道菜有两个属性:
- 价格a[i]
- 是否含蘑菇s[i](1表示含有,0表示不含)
-
需要选择k道菜
-
含蘑菇的菜不能超过m道
-
目标:在满足条件的情况下求最小总价格
关键约束
- 选菜数量必须等于k
- 含蘑菇的菜品数量必须 ≤ m
- 如果无法满足条件,返回-1
解法一:动态规划(自底向上)
思路
-
状态定义:
- dp[i][j][l] 表示考虑前i个菜,选择了j道菜,其中含l个蘑菇的最小花费
-
状态转移:
-
对于第i道菜,有两种选择:
-
不选:dp[i+1][j][l] = min(dp[i+1][j][l], dp[i][j][l])
-
选:
- 如果含蘑菇(s[i]='1'):dp[i+1][j+1][l+1] = min(dp[i+1][j+1][l+1], dp[i][j][l] + a[i])
- 如果不含蘑菇(s[i]='0'):dp[i+1][j+1][l] = min(dp[i+1][j+1][l], dp[i][j][l] + a[i])
-
-
-
初始化:
- dp[0][0][0] = 0
- 其他状态初始化为正无穷
-
最终结果:
- min(dp[n][k][l]) 其中 l∈[0,m]
- 如果结果是正无穷,返回-1 ###代码
def solution(s: str, a: list, m: int, k: int) -> int:
n = len(s)
dp = [[[float('inf')] * (m + 1) for _ in range(k + 1)] for _ in range(n + 1)]
dp[0][0][0] = 0
# l个蘑菇
for i in range(n):
for j in range(k + 1):
for l in range(m + 1):
if dp[i][j][l] == float('inf'):
continue
dp[i + 1][j][l] = min(dp[i + 1][j][l], dp[i][j][l])
if j < k:
if s[i] == '1' and l < m:
dp[i + 1][j + 1][l + 1] = min(dp[i + 1][j + 1][l + 1], dp[i][j][l] + a[i])
elif s[i] == '0':
dp[i + 1][j + 1][l] = min(dp[i + 1][j + 1][l], dp[i][j][l] + a[i])
result = min(dp[n][k][l] for l in range(m + 1))
return result if result != float('inf') else -1
if __name__ == '__main__':
print(solution("001", [10, 20, 30], 1, 2) == 30)
print(solution("111", [10, 20, 30], 1, 2) == -1)
print(solution("0101", [5, 15, 10, 20], 2, 3) == 30)
复杂度分析
- 时间复杂度:O(n×k×m)
- 空间复杂度:O(n×k×m)
解法二:DFS(自顶向下)
思路
-
递归函数参数:
- idx:当前考虑的菜品索引
- cur_m:当前选择的含蘑菇菜品数量
- cur_k:当前选择的总菜品数量
- cur_w:当前总价格
-
递归终止条件:
- 已选够k道菜:返回当前总价格
- 索引越界或含蘑菇菜品超过m:返回正无穷
-
递归过程:
-
对每道菜都有两种选择:
- 选择当前菜
- 不选择当前菜
-
返回两种选择的最小值
-
剪枝优化
- 特判:如果所有菜都含蘑菇且m<n,直接返回-1
- 当cur_m>m时提前返回
代码
def solution(s: str, a: list, m: int, k: int) -> int:
n = len(s)
if s == '1' * n and m < n: # 如果所有菜都含有蘑菇且 m < n,直接返回 -1
return -1
def dfs(idx: int, cur_m: int, cur_k: int, cur_w: int):
if cur_k == k: # 如果已经选了 k 道菜
return cur_w
if idx >= n or cur_m > m: # 如果已经遍历完所有菜或蘑菇菜超过 m
return float('inf')
# 选择当前菜
include = dfs(idx + 1, cur_m + (s[idx] == '1'), cur_k + 1, cur_w + a[idx])
# 不选择当前菜
exclude = dfs(idx + 1, cur_m, cur_k, cur_w)
return min(include, exclude)
result = dfs(0, 0, 0, 0)
return result if result != float('inf') else -1
if __name__ == '__main__':
print(solution("001", [10, 20, 30], 1, 2) == 30)
print(solution("111", [10, 20, 30], 1, 2) == -1)
print(solution("0101", [5, 15, 10, 20], 2, 3) == 30)
复杂度分析
- 时间复杂度:O(2^n),每个菜品都有选择和不选择两种状态
- 空间复杂度:O(n),递归栈的深度