灵神数位DP题目模板及训练

19 阅读16分钟

1 介绍

本专题用来记录灵神数位DP题目模板及训练。

2 数位DP模板及例题讲解

模板题目1:2376. 统计特殊整数

讲解如下:

定义递归函数f(i, mask, is_limit, is_num),它表示已使用的数字集合为mask,构造第i位及其之后数位的合法方案数。

(1)i表示第i位。

(2)mask表示已使用的数字集合:其二进制表示中第d位为1,表示数字d已被使用。

(3)is_limit表示前面已填的数位是否都与n的对应数位相同。若为true,第i位最多只能填s[i](s为n的字符串表示);否则可以填到9。这样可以保证构造出的数字不超过n。

(4)is_num表示第i位前面的数位是否已经填了数字(即是否已越过所有前导零)。若为false,第i位可以继续跳过,或者从1开始填;若为true,第i位可以从0开始填。

完整的递归函数如下,

def f(i: int, mask: int, is_limit: bool, is_num: bool) -> int:
    """Count the valid ways to fill digit positions i.. of s.

    s is the decimal string of the upper bound n (a name from the enclosing scope).
    mask: bit d is set when digit d has already been used.
    is_limit: every digit placed so far equals the prefix of s, so the
        digit at position i may not exceed s[i].
    is_num: a first (non-zero) digit has already been placed.
    """
    if i == len(s):  # s is the string form of the upper bound n
        return int(is_num)
    total = 0
    if not is_num:
        # Leave position i empty (a leading zero); the number stays "not started".
        total += f(i + 1, mask, False, False)
    hi = int(s[i]) if is_limit else 9
    first = 1 if not is_num else 0  # the first real digit cannot be 0
    for digit in range(first, hi + 1):
        if mask >> digit & 1:
            continue  # digit already used
        total += f(i + 1, mask | (1 << digit), is_limit and digit == hi, True)
    return total

注意手写记忆化时,我们只保存“是数字且不受最大值限制的状态”,即is_num == true && is_limit == false下的f(i, mask)的值。原因是:is_limit == true或is_num == false的状态在整个递归过程中只会被访问一次,缓存它们没有意义。因此在这些情况下不使用缓存,直接重新计算f(i, mask, is_limit, is_num)。

C++代码如下,

class Solution {
public:
    // Count integers in [1, n] whose decimal digits are pairwise distinct.
    int countSpecialNumbers(int n) {
        const string digits = to_string(n);
        const int len = digits.size();
        // memo[pos][used] caches only states that have started and are free of the upper bound.
        vector<vector<int>> memo(len, vector<int>(1 << 10, -1));

        function<int(int, int, bool, bool)> count = [&](int pos, int used, bool tight, bool started) -> int {
            if (pos == len) {
                return started ? 1 : 0;
            }
            const bool cacheable = started && !tight;
            if (cacheable && memo[pos][used] != -1) {
                return memo[pos][used];
            }
            int total = 0;
            if (!started) {
                total += count(pos + 1, used, false, false);  // leading zero: skip this position
            }
            const int hi = tight ? digits[pos] - '0' : 9;
            for (int d = started ? 0 : 1; d <= hi; ++d) {
                if (used & (1 << d)) {
                    continue;  // digit already taken
                }
                total += count(pos + 1, used | (1 << d), tight && d == hi, true);
            }
            if (cacheable) {
                memo[pos][used] = total;
            }
            return total;
        };
        return count(0, 0, true, false);
    }
};

python3代码如下,

class Solution:
    def countSpecialNumbers(self, n: int) -> int:
        """Return how many integers in [1, n] have pairwise-distinct decimal digits."""
        s = str(n)  # decimal string of the upper bound

        # maxsize=None: a bare @lru_cache keeps only 128 entries, but this dfs has
        # up to len(s) * 2**10 * 4 states, so the default cache would keep
        # evicting entries and recomputing them.
        @lru_cache(maxsize=None)
        def f(i: int, mask: int, is_limit: bool, is_num: bool) -> int:
            """Count the valid ways to fill positions i.. of s.

            mask: bit d is set when digit d has already been used.
            is_limit: the digits so far equal s's prefix, so digit i is capped at s[i].
            is_num: a first (non-zero) digit has already been placed.
            """
            if i == len(s):
                return int(is_num)
            res = 0
            if not is_num:
                # Skip position i (leading zero); the number stays shorter than s.
                res += f(i + 1, mask, False, False)
            low = 0 if is_num else 1  # the leading digit cannot be 0
            up = int(s[i]) if is_limit else 9
            for d in range(low, up + 1):
                if (mask >> d & 1) == 0:
                    res += f(i + 1, mask | (1 << d), is_limit and d == up, True)
            return res

        return f(0, 0, True, False)

模板题目2:2999. 统计强大整数的数目

讲解如下:定义递归函数dfs(i, is_low_limit, is_high_limit, is_num),它表示构造第i位及其之后数位的合法方案数。

其中,

(1)i:表示第i位。

(2)is_low_limit:表示第i位之前的数位是否与范围最小值的第i位之前的部分一致。如果一致则为true,表示受到了范围最小值的约束。

(3)is_high_limit:表示第i位之前的数位是否与范围最大值的第i位之前的部分一致。如果一致则为true,表示受到了范围最大值的约束。

(4)is_num:表示第i位之前是否已经填了数字(即是否已填过非零的首位),如果已填则为true。

完整的递归函数如下,

@cache
def dfs(i: int, is_low_limit: bool, is_high_limit: bool, is_num: bool) -> int:
    """Count powerful integers obtainable by filling digit positions i.. onward.

    Reads names from the enclosing scope: start (lower bound, zero-padded to
    length n), finish (upper bound), n (its length), s (required suffix),
    diff (index where s begins) and limit (largest allowed digit).

    is_low_limit / is_high_limit: the prefix so far equals the prefix of
        start / finish, so digit i is bounded by start[i] / finish[i].
    is_num: a non-zero first digit has already been placed.
    """
    if i == n:
        return int(is_num)
    total = 0
    # Leading zero: allowed only before the suffix and while start also has a 0 here.
    if not is_num and start[i] == '0' and i < diff:
        total += dfs(i + 1, True, False, False)

    # Range implied by the lower/upper bounds ...
    lo = int(start[i]) if is_low_limit else 0
    hi = int(finish[i]) if is_high_limit else 9
    # ... intersected with the range implied by is_num and limit.
    first = max(lo, 1 if not is_num else 0)
    last = min(hi, limit)

    if i >= diff:
        # Inside the suffix the digit is forced to s[i - diff].
        forced = int(s[i - diff])
        if first <= forced <= last:
            total += dfs(i + 1, is_low_limit and forced == lo, is_high_limit and forced == hi, True)
        return total
    for digit in range(first, last + 1):
        total += dfs(i + 1, is_low_limit and digit == lo, is_high_limit and digit == hi, True)
    return total

注意手写记忆化时,我们只保存“是数字、同时不受上下界限制的状态”,即is_num && !is_low_limit && !is_high_limit。此时返回值只与i有关,因此只需要记忆化一维状态i。

由于本题已把start用前导零补齐到与finish等长,而前导零既满足“每一位不超过limit”的约束,也不改变数值,所以可以把is_num去掉,直接把前导零当作普通的0来填,得到非is_num版本。该版本的递归函数如下,

@cache
def dfs(i: int, is_low_limit: bool, is_high_limit: bool) -> int:
    """Variant of the powerful-integer dfs without is_num.

    Leading zeros are filled as ordinary 0 digits (start is zero-padded to
    length n). Reads start, finish, n, s, diff and limit from the enclosing scope.
    """
    if i == n:
        return 1
    lo = int(start[i]) if is_low_limit else 0
    hi = int(finish[i]) if is_high_limit else 9
    # Without is_num the lower digit bound is just lo (max(lo, 0) == lo).
    first, last = lo, min(hi, limit)

    if i >= diff:
        # Inside the suffix the digit is forced to s[i - diff].
        forced = int(s[i - diff])
        if not first <= forced <= last:
            return 0
        return dfs(i + 1, is_low_limit and forced == lo, is_high_limit and forced == hi)

    total = 0
    for digit in range(first, last + 1):
        total += dfs(i + 1, is_low_limit and digit == lo, is_high_limit and digit == hi)
    return total

C++代码如下,

class Solution {
public:
    // Count integers in [start, finish] that end with suffix s and whose digits are all <= limit.
    long long numberOfPowerfulInt(long long start, long long finish, int limit, string s) {
        const string high = to_string(finish);
        const int n = high.size();
        const int diff = n - s.size();  // index where the suffix s begins
        string low_str = to_string(start);
        low_str.insert(0, n - low_str.size(), '0');  // zero-pad the lower bound to n digits

        vector<long long> memo(20, -1);  // only for started states free of both bounds

        function<long long(int, bool, bool, bool)> dfs = [&](int i, bool low_tight, bool high_tight, bool started) -> long long {
            if (i == n) {
                return started;
            }
            const bool cacheable = started && !low_tight && !high_tight;
            if (cacheable && memo[i] != -1) {
                return memo[i];
            }

            long long total = 0;
            if (!started && low_str[i] == '0' && i < diff) {
                total += dfs(i + 1, true, false, false);  // leading zero
            }

            const int lo = low_tight ? low_str[i] - '0' : 0;
            const int hi = high_tight ? high[i] - '0' : 9;
            const int first = max(lo, started ? 0 : 1);
            const int last = min(hi, limit);

            if (i >= diff) {
                // Inside the suffix the digit is forced.
                const int forced = s[i - diff] - '0';
                if (first <= forced && forced <= last) {
                    total += dfs(i + 1, low_tight && forced == lo, high_tight && forced == hi, true);
                }
            } else {
                for (int d = first; d <= last; ++d) {
                    total += dfs(i + 1, low_tight && d == lo, high_tight && d == hi, true);
                }
            }
            if (cacheable) {
                memo[i] = total;
            }
            return total;
        };
        return dfs(0, true, true, false);
    }
};

python3代码如下,

class Solution:
    def numberOfPowerfulInt(self, start: int, finish: int, limit: int, s: str) -> int:
        """Count integers in [start, finish] that end with s and have every digit <= limit."""
        hi_str = str(finish)
        n = len(hi_str)
        lo_str = str(start).rjust(n, "0")  # pad the lower bound to the same width
        diff = n - len(s)  # index where the required suffix begins

        @lru_cache
        def dfs(i: int, low_tight: bool, high_tight: bool, started: bool) -> int:
            if i == n:
                return int(started)
            total = 0
            # Leading zero: allowed only before the suffix while start has a 0 here.
            if not started and lo_str[i] == '0' and i < diff:
                total += dfs(i + 1, True, False, False)

            # Digit range from the bounds, intersected with the is_num/limit range.
            lo = int(lo_str[i]) if low_tight else 0
            hi = int(hi_str[i]) if high_tight else 9
            first = max(lo, 0 if started else 1)
            last = min(hi, limit)

            if i < diff:
                candidates = range(first, last + 1)
            else:
                forced = int(s[i - diff])  # suffix digits are fixed
                candidates = [forced] if first <= forced <= last else []
            for d in candidates:
                total += dfs(i + 1, low_tight and d == lo, high_tight and d == hi, True)
            return total

        return dfs(0, True, True, False)

3 训练

题目1:2719. 统计整数数目

解题思路:带最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count x in [num1, num2] whose digit sum lies in [min_sum, max_sum], modulo 1e9+7.
    int count(string num1, string num2, int min_sum, int max_sum) {
        const int n = num2.size();
        num1.insert(0, n - num1.size(), '0');  // zero-pad the lower bound
        const int mod = 1e9 + 7;

        int memo[25][410];  // memo[i][digit_sum] for states free of both bounds
        memset(memo, -1, sizeof memo);

        function<int(int, int, bool, bool)> dfs = [&](int i, int sum, bool low_tight, bool high_tight) -> int {
            if (i == n) {
                return sum >= min_sum;
            }
            const bool unbounded = !low_tight && !high_tight;
            if (unbounded && memo[i][sum] != -1) {
                return memo[i][sum];
            }
            const int lo = low_tight ? num1[i] - '0' : 0;
            const int hi = high_tight ? num2[i] - '0' : 9;
            int total = 0;
            // Larger digits only push the sum further past max_sum, so stop early.
            for (int d = lo; d <= hi && sum + d <= max_sum; ++d) {
                total = (total + dfs(i + 1, sum + d, low_tight && d == lo, high_tight && d == hi)) % mod;
            }
            if (unbounded) {
                memo[i][sum] = total;
            }
            return total;
        };
        return dfs(0, 0, true, true);
    }
};

python3代码如下,

#is_num版本
class Solution:
    def count(self, num1: str, num2: str, min_sum: int, max_sum: int) -> int:
        """Count x in [num1, num2] with digit sum in [min_sum, max_sum], mod 1e9+7 (is_num variant)."""
        n = len(num2)
        num1 = num1.rjust(n, "0")  # zero-pad the lower bound

        @cache
        def dfs(i: int, curr_sum: int, is_low_limit: bool, is_high_limit: bool, is_num: bool) -> int:
            if i == n:
                return is_num and curr_sum >= min_sum
            res = 0
            if not is_num and num1[i] == '0':
                # Leading zero; num2 has no leading zeros, so the upper bound is released.
                res += dfs(i + 1, 0, True, False, False)

            lo = int(num1[i]) if is_low_limit else 0
            hi = int(num2[i]) if is_high_limit else 9
            first = max(lo, 0 if is_num else 1)

            for d in range(first, hi + 1):
                if curr_sum + d > max_sum:
                    break  # larger digits only overshoot further
                res += dfs(i + 1, curr_sum + d, is_low_limit and d == lo, is_high_limit and d == hi, True)
            return res

        return dfs(0, 0, True, True, False) % int(1e9+7)
            
#非is_num版本
class Solution:
    def count(self, num1: str, num2: str, min_sum: int, max_sum: int) -> int:
        """Count x in [num1, num2] with digit sum in [min_sum, max_sum], mod 1e9+7.

        Variant without is_num: leading zeros add nothing to the digit sum.
        """
        upper = str(num2)
        n = len(upper)
        lower = str(num1).rjust(n, "0")

        @cache  # a bare @lru_cache (maxsize=128) is too small here and times out
        def dfs(i: int, curr_sum: int, is_low_limit: bool, is_high_limit: bool) -> int:
            if curr_sum > max_sum:
                return 0  # prune: the sum can only grow
            if i == n:
                return curr_sum >= min_sum
            lo = int(lower[i]) if is_low_limit else 0
            hi = int(upper[i]) if is_high_limit else 9
            return sum(
                dfs(i + 1, curr_sum + d, is_low_limit and d == lo, is_high_limit and d == hi)
                for d in range(lo, hi + 1)
            )

        return dfs(0, 0, True, True) % int(1e9+7)

题目2:788. 旋转数字

解题思路:不带最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count good integers in [1, n]: every digit is rotatable and at least one rotates to a different digit.
    int rotatedDigits(int n) {
        const string s = to_string(n);
        const int len = s.size();
        int memo[7];  // memo[i] for started, unbounded states that already contain a changing digit
        memset(memo, -1, sizeof memo);

        function<int(int, bool, bool, bool)> dfs = [&](int i, bool is_diff, bool is_limit, bool is_num) -> int {
            if (i == len) {
                return is_num && is_diff;
            }
            const bool cacheable = !is_limit && is_num && is_diff;
            if (cacheable && memo[i] != -1) {
                return memo[i];
            }
            int res = is_num ? 0 : dfs(i + 1, false, false, false);
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d = is_num ? 0 : 1; d <= up; ++d) {
                switch (d) {
                    case 0: case 1: case 8:  // rotate to themselves
                        res += dfs(i + 1, is_diff, is_limit && d == up, true);
                        break;
                    case 2: case 5: case 6: case 9:  // rotate to a different digit
                        res += dfs(i + 1, true, is_limit && d == up, true);
                        break;
                    default:  // 3, 4, 7 are invalid after rotation
                        break;
                }
            }
            if (cacheable) {
                memo[i] = res;
            }
            return res;
        };
        return dfs(0, false, true, false);
    }
};

python3代码如下,

class Solution:
    def rotatedDigits(self, n: int) -> int:
        """Count good integers in [1, n] (all digits rotatable, at least one changes)."""
        s = str(n)
        width = len(s)

        @cache
        def dfs(i: int, is_diff: bool, is_limit: bool, is_num: bool) -> int:
            if i == width:
                return is_num and is_diff
            total = 0
            if not is_num:
                total += dfs(i + 1, False, False, False)  # leading zero
            up = int(s[i]) if is_limit else 9
            for d in range(0 if is_num else 1, up + 1):
                if d in (2, 5, 6, 9):
                    nxt = True  # rotates to a different digit
                elif d in (0, 1, 8):
                    nxt = is_diff  # rotates to itself
                else:
                    continue  # 3, 4, 7 are invalid after rotation
                total += dfs(i + 1, nxt, is_limit and d == up, True)
            return total

        return dfs(0, False, True, False)

题目3:902. 最大为 N 的数字组合

解题思路:无下界限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count positive integers <= n that can be written using only the given digits (each reusable).
    int atMostNGivenDigitSet(vector<string>& digits, int n) {
        const string s = to_string(n);
        const int len = s.size();
        vector<int> values;  // parse the digit strings once
        for (const string& x : digits) {
            values.push_back(stoi(x));
        }
        int memo[12];  // memo[i] for started states not bounded by n
        memset(memo, -1, sizeof memo);

        function<int(int, bool, bool)> dfs = [&](int i, bool is_limit, bool is_num) -> int {
            if (i == len) {
                return is_num;
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i] != -1) {
                return memo[i];
            }
            int res = is_num ? 0 : dfs(i + 1, false, false);  // leading zero
            const int low = is_num ? 0 : 1;
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d : values) {
                if (d >= low && d <= up) {
                    res += dfs(i + 1, is_limit && d == up, true);
                }
            }
            if (cacheable) {
                memo[i] = res;
            }
            return res;
        };
        return dfs(0, true, false);
    }
};

python3代码如下,

class Solution:
    def atMostNGivenDigitSet(self, digits: List[str], n: int) -> int:
        """Count positive integers <= n written only with the given digits (reusable)."""
        s = str(n)
        width = len(s)
        values = [int(x) for x in digits]  # parse the digit strings once

        @cache
        def dfs(i: int, is_limit: bool, is_num: bool) -> int:
            if i == width:
                return is_num
            total = 0
            if not is_num:
                total += dfs(i + 1, False, False)  # leading zero
            low = 0 if is_num else 1
            up = int(s[i]) if is_limit else 9
            for d in values:
                if low <= d <= up:
                    total += dfs(i + 1, is_limit and d == up, True)
            return total

        return dfs(0, True, False)
        

题目4:233. 数字 1 的个数

解题思路:非下界限制的数位DP。

C++代码如下,

class Solution {
public:
    // Total count of the digit 1 over all integers in [0, n].
    int countDigitOne(int n) {
        const string s = to_string(n);
        const int len = s.size();
        int memo[12][12];  // memo[i][ones so far], only for started states not bounded by n
        memset(memo, -1, sizeof memo);

        function<int(int, int, bool, bool)> dfs = [&](int i, int ones, bool is_limit, bool is_num) -> int {
            if (i == len) {
                return ones;
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i][ones] != -1) {
                return memo[i][ones];
            }
            int res = is_num ? 0 : dfs(i + 1, 0, false, false);  // leading zero
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d = is_num ? 0 : 1; d <= up; ++d) {
                res += dfs(i + 1, ones + (d == 1 ? 1 : 0), is_limit && d == up, true);
            }
            if (cacheable) {
                memo[i][ones] = res;
            }
            return res;
        };
        return dfs(0, 0, true, false);
    }
};

python3代码如下,

class Solution:
    def countDigitOne(self, n: int) -> int:
        """Return the total number of digit 1 appearing in all integers from 0 to n."""
        digits = str(n)
        width = len(digits)

        @cache
        def dfs(pos: int, ones: int, tight: bool, started: bool) -> int:
            # ones: how many 1s have been placed so far on this path.
            if pos == width:
                return ones
            total = 0 if started else dfs(pos + 1, 0, False, False)
            hi = int(digits[pos]) if tight else 9
            for d in range(0 if started else 1, hi + 1):
                total += dfs(pos + 1, ones + (d == 1), tight and d == hi, True)
            return total

        return dfs(0, 0, True, False)

题目5:面试题 17.06. 2出现的次数

解题思路:非最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Total count of the digit 2 over all integers in [0, n].
    int numberOf2sInRange(int n) {
        const string s = to_string(n);
        const int len = s.size();
        int memo[12][12];  // memo[i][twos so far], only for started states not bounded by n
        memset(memo, -1, sizeof memo);

        function<int(int, int, bool, bool)> dfs = [&](int i, int twos, bool is_limit, bool is_num) -> int {
            if (i == len) {
                return twos;
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i][twos] != -1) {
                return memo[i][twos];
            }
            int res = is_num ? 0 : dfs(i + 1, 0, false, false);  // leading zero
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d = is_num ? 0 : 1; d <= up; ++d) {
                res += dfs(i + 1, twos + (d == 2 ? 1 : 0), is_limit && d == up, true);
            }
            if (cacheable) {
                memo[i][twos] = res;
            }
            return res;
        };
        return dfs(0, 0, true, false);
    }
};

python3代码如下,

class Solution:
    def numberOf2sInRange(self, n: int) -> int:
        """Return the total number of digit 2 appearing in all integers from 0 to n."""
        digits = str(n)
        width = len(digits)

        @cache
        def dfs(pos: int, twos: int, tight: bool, started: bool) -> int:
            # twos: how many 2s have been placed so far on this path.
            if pos == width:
                return twos
            total = 0 if started else dfs(pos + 1, 0, False, False)
            hi = int(digits[pos]) if tight else 9
            for d in range(0 if started else 1, hi + 1):
                total += dfs(pos + 1, twos + (d == 2), tight and d == hi, True)
            return total

        return dfs(0, 0, True, False)

题目6:600. 不含连续1的非负整数

解题思路:在n的二进制表示上做数位DP。0本身也是合法方案(不含连续的1),前导零也不会产生连续的1,所以不需要is_num。这是无最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count integers in [0, n] whose binary form has no two consecutive 1 bits.
    int findIntegers(int n) {
        const string bits = bitset<32>(n).to_string();  // fixed 32-bit width, leading zeros included
        const int len = bits.size();
        int memo[35][2];  // memo[i][previous bit] for states not bounded by n
        memset(memo, -1, sizeof memo);

        function<int(int, int, bool)> dfs = [&](int i, int prev, bool is_limit) -> int {
            if (i == len) {
                return 1;
            }
            if (!is_limit && memo[i][prev] != -1) {
                return memo[i][prev];
            }
            int res = 0;
            const int up = is_limit ? bits[i] - '0' : 1;
            for (int d = 0; d <= up; ++d) {
                const bool adjacent_ones = i != 0 && prev == 1 && d == 1;
                if (!adjacent_ones) {
                    res += dfs(i + 1, d, is_limit && d == up);
                }
            }
            if (!is_limit) {
                memo[i][prev] = res;
            }
            return res;
        };
        return dfs(0, 0, true);
    }
};

python3代码如下,

class Solution:
    def findIntegers(self, n: int) -> int:
        """Count integers in [0, n] whose binary form has no two adjacent 1 bits."""
        bits = bin(n)[2:]  # binary representation of n without the '0b' prefix
        width = len(bits)

        @cache
        def dfs(i: int, prev: int, is_limit: bool) -> int:
            if i == width:
                return 1
            up = int(bits[i]) if is_limit else 1
            total = 0
            for d in range(up + 1):
                if d == 1 and prev == 1 and i != 0:
                    continue  # would create two adjacent 1s
                total += dfs(i + 1, d, is_limit and d == up)
            return total

        return dfs(0, 0, True)

题目7:1012. 至少有 1 位重复的数字

解题思路:正难则反。求没有重复数位的数字的个数。非最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count integers in [1, n] with at least one repeated digit: n minus the count with all-distinct digits.
    int numDupDigitsAtMostN(int n) {
        const int upper = n;
        const string s = to_string(n);
        const int len = s.size();
        int memo[12][1 << 10];  // memo[i][used digits] for started states not bounded by n
        memset(memo, -1, sizeof memo);

        // Counts integers in [1, n] whose digits are pairwise distinct.
        function<int(int, int, bool, bool)> distinct = [&](int i, int mask, bool is_limit, bool is_num) -> int {
            if (i == len) {
                return is_num;
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i][mask] != -1) {
                return memo[i][mask];
            }
            int res = is_num ? 0 : distinct(i + 1, 0, false, false);  // leading zero
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d = is_num ? 0 : 1; d <= up; ++d) {
                if (mask >> d & 1) {
                    continue;  // digit already used
                }
                res += distinct(i + 1, mask | 1 << d, is_limit && d == up, true);
            }
            if (cacheable) {
                memo[i][mask] = res;
            }
            return res;
        };
        return upper - distinct(0, 0, true, false);
    }
};
// Alternative: count numbers with a repeated digit directly.
// The earlier version of this code timed out because it only memoized states whose
// is_repeated flag was already true; the far larger set of not-yet-repeated states was
// recomputed on every visit. Memoizing on is_repeated as well removes that blow-up.
class Solution {
public:
    int numDupDigitsAtMostN(int n) {
        string s = to_string(n);
        n = s.size();

        // memo[i][mask][is_repeated]; used only for states that have started and are not bounded by n.
        int memo[12][1 << 10][2];
        memset(memo, -1, sizeof memo);

        function<int(int,int,bool,bool,bool)> dfs = [&](int i, int mask, bool is_repeated, bool is_limit, bool is_num) -> int {
            if (i == n) {
                return is_num && is_repeated;
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i][mask][is_repeated] != -1) {
                return memo[i][mask][is_repeated];
            }
            int res = 0;
            if (!is_num) {
                res += dfs(i+1, 0, false, false, false);  // leading zero: skip position i
            }
            int low = is_num ? 0 : 1;
            int up = is_limit ? s[i]-'0' : 9;
            for (int d = low; d <= up; ++d) {
                // Reusing a digit already in mask makes the number "repeated" for good.
                bool repeated = is_repeated || (mask >> d & 1);
                res += dfs(i+1, mask | 1 << d, repeated, is_limit && d == up, true);
            }
            if (cacheable) {
                memo[i][mask][is_repeated] = res;
            }
            return res;
        };
        return dfs(0, 0, false, true, false);
    }
};

python3代码如下,

class Solution:
    def numDupDigitsAtMostN(self, n: int) -> int:
        """Count integers in [1, n] that contain at least one repeated digit."""
        s = str(n)
        width = len(s)

        @cache
        def dfs(i: int, used: int, repeated: bool, is_limit: bool, is_num: bool) -> int:
            if i == width:
                return repeated and is_num
            total = 0
            if not is_num:
                total += dfs(i + 1, 0, False, False, False)  # leading zero
            low = 0 if is_num else 1
            up = int(s[i]) if is_limit else 9
            for d in range(low, up + 1):
                bit = 1 << d
                # Once a digit repeats, the flag stays True for the rest of the number.
                total += dfs(i + 1, used | bit, repeated or bool(used & bit), is_limit and d == up, True)
            return total

        return dfs(0, 0, False, True, False)

题目8:357. 统计各位数字都不同的数字个数

解题思路:非最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // Count x in [0, 10^n) whose decimal digits are pairwise distinct.
    int countNumbersWithUniqueDigits(int n) {
        if (n == 0) {
            return 1;  // only x = 0
        }
        const string s(n, '9');  // upper bound 10^n - 1
        int memo[10][1 << 10];  // memo[i][used digits] for started states not bounded by s
        memset(memo, -1, sizeof memo);

        function<int(int, int, bool, bool)> dfs = [&](int i, int mask, bool is_limit, bool is_num) -> int {
            if (i == n) {
                return 1;  // every completed path counts, including the all-skipped number 0
            }
            const bool cacheable = !is_limit && is_num;
            if (cacheable && memo[i][mask] != -1) {
                return memo[i][mask];
            }
            int res = is_num ? 0 : dfs(i + 1, 0, false, false);  // leading zero
            const int up = is_limit ? s[i] - '0' : 9;
            for (int d = is_num ? 0 : 1; d <= up; ++d) {
                if (mask >> d & 1) {
                    continue;  // digit already used
                }
                res += dfs(i + 1, mask | 1 << d, is_limit && d == up, true);
            }
            if (cacheable) {
                memo[i][mask] = res;
            }
            return res;
        };
        return dfs(0, 0, true, false);
    }
};

python3代码如下,

class Solution:
    def countNumbersWithUniqueDigits(self, n: int) -> int:
        """Count x in [0, 10**n) whose decimal digits are pairwise distinct."""
        s = str(10 ** n - 1)  # upper bound written as n nines ("0" when n == 0)
        width = len(s)

        @cache
        def dfs(i: int, mask: int, is_limit: bool, is_num: bool) -> int:
            if i == width:
                return 1  # every completed path counts, including 0
            total = 0 if is_num else dfs(i + 1, 0, False, False)
            up = int(s[i]) if is_limit else 9
            for d in range(0 if is_num else 1, up + 1):
                if not mask >> d & 1:
                    total += dfs(i + 1, mask | 1 << d, is_limit and d == up, True)
            return total

        return dfs(0, 0, True, False)

题目9:3007. 价值和小于等于 K 的最大数字

解题思路:二分查找。无最小值限制的数位DP。

C++代码如下,

class Solution {
public:
    // True when the accumulated price of 1..mid is at most k. The price of a number is the count
    // of its set bits at 1-indexed positions (counted from the right) that are multiples of x.
    bool check(long long mid, const int &x, const long long &k) {
        const string bits = bitset<64>(mid).to_string();
        const int len = bits.size();
        long long memo[68][68];  // memo[i][price so far] for states not bounded by mid
        memset(memo, -1, sizeof memo);

        function<long long(int, int, bool)> price = [&](int pos, int acc, bool tight) -> long long {
            if (pos == len) {
                return acc;
            }
            if (!tight && memo[pos][acc] != -1) {
                return memo[pos][acc];
            }
            long long total = 0;
            const int hi = tight ? bits[pos] - '0' : 1;
            const bool counted = (len - pos) % x == 0;  // 1-indexed position from the right
            for (int b = 0; b <= hi; ++b) {
                total += price(pos + 1, acc + ((b == 1 && counted) ? 1 : 0), tight && b == hi);
            }
            if (!tight) {
                memo[pos][acc] = total;
            }
            return total;
        };
        return price(0, 0, true) <= k;
    }

    // Binary search the largest num whose accumulated price is <= k.
    long long findMaximumNumber(long long k, int x) {
        long long lo = 0;
        long long hi = (k + 1) * (1 << x);
        long long best = -1;
        while (lo <= hi) {
            const long long mid = (lo + hi) / 2;
            if (check(mid, x, k)) {
                best = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return best;
    }
};

python3代码如下,

class Solution:
    def findMaximumNumber(self, k: int, x: int) -> int:
        """Return the largest num such that the total price of 1..num is <= k.

        The price of a number is the count of its set bits at 1-indexed
        positions (counted from the least-significant bit) that are multiples of x.
        """

        def check(mid: int) -> bool:
            # True when the accumulated price of 1..mid is at most k.
            bits = bin(mid)[2:]
            width = len(bits)

            @cache
            def dfs(i: int, acc: int, is_limit: bool) -> int:
                if i == width:
                    return acc
                up = int(bits[i]) if is_limit else 1
                counted = (width - i) % x == 0  # position of bit i, counted from the right
                return sum(
                    dfs(i + 1, acc + (d == 1 and counted), is_limit and d == up)
                    for d in range(up + 1)
                )

            return dfs(0, 0, True) <= k

        lo, hi = 0, (k + 1) * 2 ** x  # an upper bound on the answer
        best = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if check(mid):
                best, lo = mid, mid + 1
            else:
                hi = mid - 1
        return best

题目10:2827. 范围中美丽整数的数目

解题思路:

C++代码如下,

python3代码如下,

4 参考

分享丨【题单】动态规划(入门/背包/状态机/划分/区间/状压/数位/树形/数据结构优化)