def solution(s: str) -> int:
# write code here
MOD = 10**9 + 7
n = len(s)
# 计算问号的位置
q_positions = []
for i in range(n):
if s[i] == '?':
q_positions.append(i)
k = len(q_positions) # 问号的数量
total_weight = 0
# 枚举所有可能的组合
for mask in range(1 << k):
# 构造当前组合对应的字符串
current = list(s)
for i in range(k):
# 根据mask的第i位决定问号替换成0还是1
current[q_positions[i]] = '1' if (mask & (1 << i)) else '0'
# 计算当前字符串的权值
weight = 0
for i in range(n - 1):
if current[i] != current[i + 1]:
weight += 1
total_weight = (total_weight + weight) % MOD
return total_weight
# return 0 # placeholder return
if __name__ == '__main__':
print(solution("01?") == 3)
print(solution("1??0") == 6)
print(solution("?") == 0)
思路:
-
问题理解
- 输入是一个包含 '0'、'1' 和 '?' 的字符串
- 每个 '?' 可以被替换成 '0' 或 '1'
- 字符串的权值定义为:相邻不同字符对的数量
- 需要计算所有可能字符串的权值之和
-
关键观察
- 如果有 k 个问号,就有 2^k 种可能的字符串
- 每个问号有两种-选择(0或1)
- 这自然联想到可以用二进制数来表示所有可能的组合
-
位运算的应用
例如对于字符串 "1??0":
k = 2(两个问号) mask从0到3 (二进制:00, 01, 10, 11)mask = 0 (00)表示两个问号都替换成0:1000mask = 1 (01)表示第一个?替换成0,第二个替换成1:1010mask = 2 (10)表示第一个?替换成1,第二个替换成0:1100mask = 3 (11)表示两个问号都替换成1:1110 -
具体实现步骤
a. 找出问号位置
- 遍历字符串,记录所有问号的位置
- 这些位置之后会用来替换字符
q_positions = [] for i in range(n): if s[i] == '?': q_positions.append(i)b. 枚举所有可能组合
- 使用一个循环从0到2^k-1
- 每个数字代表一种问号替换方案
for mask in range(1 << k):# 1 << k 等价于 2^kc. 构造具体字符串
- 将原字符串转换为列表以便修改
- 对每个问号位置,根据mask的对应位决定替换成'0'还是'1'
mask & (1 << i)用来检查mask的第i位是否为1
current = list(s) for i in range(k): current[q_positions[i]] = '1' if (mask & (1 << i)) else '0'd. 计算权
- 遍历相邻字符对
- 当相邻字符不同时,权值加1
weight = 0 for i in range(n - 1): if current[i] != current[i + 1]: weight += 1 -
示例分析 以 "1??0" 为例:
原串:1??0 可能的组合:
mask=0 (00): 1000 权值=1 (一对不同)mask=1 (01): 1010 权值=2 (两对不同)mask=2 (10): 1100 权值=1 (一对不同)mask=3 (11): 1110 权值=2 (两对不同) 总权值:1+2+1+2=6 -
优化思考
- 时间复杂度:O(2^k * n),其中k是问号数量,n是字符串长度
- 空间复杂度:O(n),主要用于存储当前处理的字符串
- 对于问号数量较多的情况,可能需要考虑更优的解法
- 对于特殊情况(如全是问号或没有问号)可以添加快速判断
-
模运算处理
- 所有累加操作都需要对10^9 + 7取模
- 这是为了防止数值溢出,同时保持结果的正确性