题目描述
思路
套用数位DP代码,这个模板之前讲过,请见我的2376. 统计特殊整数题解。
简单回顾一下这个模板,模板采用的是填充数位的思想,需要记录idx下标,is_limit状态位表示是否顶到了右区间极限,is_num表示是否在填充前导零。剩下的dfs状态就和具体题目要求有关系了。
但是这个题有个比较细节的地方我给你讲一下,我原先的代码是这样的,你从下面的代码观察到,在dfs我传入了even数位上的累计值和odd数位上的累计值。 从道理上讲是没问题的,但是sum(even)和sum(odd)的状态组合非常大,深搜会极其消耗时间,所以下面的代码会超时。
我们退一步想,我们最终填充完毕之后,只需要考虑|even-odd| == 0,所以本质上你没必要一直计算两个累计值,只需要把diff:=even-odd传给dfs就好了。所以你需要将代码中even和odd状态更换成diff这一个状态,这将状态空间从“二维面积”降低到了“一维线性”,极大减少了计算量。
TLE代码
class Solution:
def countBalanced(self, low: int, high: int) -> int:
@cache
def dfs(idx : int , is_limit : bool , is_num : bool , even : int , odd : int , bound : int) -> int :
if idx == len(str(bound)) :
return even == odd
res = 0
if not is_num :
res = dfs(idx + 1 , False , False , 0 , 0 , bound)
up = int(str(bound)[idx]) if is_limit else 9
for d in range(1 - int(is_num) , up + 1) :
if idx % 2 == 0 :
res += dfs(idx + 1 , is_limit and d == up , True , even , odd + d, bound)
else :
res += dfs(idx + 1 , is_limit and d == up , True , even + d , odd, bound)
return res
return dfs(0 , True , False , 0 , 0 , high) - dfs(0 , True , False , 0 , 0 ,low - 1)
代码
无剪枝优化
class Solution:
def countBalanced(self, low: int, high: int) -> int:
# diff := even - odd
@cache
def dfs(idx : int , is_limit: bool , is_num : bool , diff: int , bound : str) :
if len(bound) == idx :
return diff == 0
res = 0
if not is_num:
res = dfs(idx + 1, False , False , 0 , bound)
up = int(bound[idx]) if is_limit else 9
for d in range( 1 - int(is_num) , up + 1) :
if idx % 2 == 0 :
res += dfs(idx + 1, is_limit and d == up , True , diff - d , bound)
else :
res += dfs(idx + 1, is_limit and d == up , True , diff + d , bound)
return res
return dfs(0,True, False , 0 , str(high)) - dfs(0,True,False,0,str(low-1))
含剪枝优化: 字符串长度信息与dfs形成闭包,避免dfs的时候把不变量传进去
class Solution:
def countBalanced(self, low: int, high: int) -> int:
def cal(s:str) :
# diff := even - odd
n = len(s)
@cache
def dfs(idx : int , is_limit: bool , is_num : bool , diff: int) :
if n == idx :
return diff == 0
# 有效剪枝
if abs(diff) > (n - idx) * 9 :
return 0
res = 0
if not is_num:
res = dfs(idx + 1, False , False , 0 )
up = int(s[idx]) if is_limit else 9
for d in range( 1 - int(is_num) , up + 1) :
if idx % 2 == 0 :
res += dfs(idx + 1, is_limit and d == up , True , diff - d )
else :
res += dfs(idx + 1, is_limit and d == up , True , diff + d )
return res
return dfs(0,True, False , 0 )
return cal(str(high)) - cal(str(low-1))