DP:数位DP

文章讲述了如何运用动态规划(DP)解决两个编程问题:LC2376统计特殊整数和LC233统计数字1的个数,涉及状态转移方程、记忆化搜索以及处理前导零和上限限制的情况。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数位DP的大致思想:枚举每一位能选取的合法值。

1. LC 2376 统计特殊整数

说是DP,但实际上状态转移方程挺难写的,毕竟是枚举+集合论,这里就不贴状态转移方程了。总体的写法其实是搜索+记忆化。之所以称之为DP,是因为:

  1. 对于第i位,
    1. 如果[0,i-1]位全部都选的能选的最大值,那么第i位最多也就只能选到最大值(注意第0位肯定是受限制的)
    2. 否则,第i位能随便选
  2. 如果前面[0,i-1]位都是0(前导零),这一位就可以随便选了

有这两大状态转移的规则,所以还是被称之为DP。细节写在代码里了

import java.util.Arrays;

class Solution {
    char[] s;
    int[][] memo;
    public int countSpecialNumbers(int n) {
        s = Integer.toString(n).toCharArray();

        // 记搜
        // memo[i][mask]代表[0,i-1]位选掉了mask中各个索引位置代表的数的情况下后面有多少个特殊整数
        memo = new int[s.length][1<<10];
        for (int i = 0; i < memo.length; i++) {
            Arrays.fill(memo[i],-1); // -1 代表没有计算过
        }
        // 一开始第一位当然要受限制,不过也可以是前导零,所以 (true,false)
        return f(0,0,true,false);
    }

    /**
     * 记搜计算在确定[0,i-1]位后剩下的有种特殊整数的情况
     * @param i 第i位
     * @param mask 已经选择了mask集合中的数
     * @param isLimit 当前位上限是否有限制,有的话是 int(s[i]),没有的话是9
     * @param isNum 当前位是否跳过,即前导零
     * @return 确定[0,i-1]位后剩下的有种特殊整数的情况
     */
    private int f(int i,int mask,boolean isLimit,boolean isNum){
        if(i==s.length){
            return isNum ?1:0; // 防止从头到尾的前导零,这种情况根本不是个数字,当然不能算
        }

        /*
          这个!isLimit条件非常重要,举个例子,n=420
          假设现在前两位选了 4,2,那么mask = {4,2},第三位就只能选0
          但如果前两位选了2,4,mask={4,2},第三位可以随便选(除了2,4)
          在第一种情况下能选1种,第二种情况下能选8种(10-2),差了7种情况
          而之后的循环枚举会考虑所有顶着最大上限选的情况,所以如果当前这个数位受到最大上限的限制的话,后面的for循环会统计这个情况的
          而不是把非受限制的情况的记忆化搜索结果返回,在这个例子里dp[2][{2,4}] = 8 而不是1
          但isNum不是必须判断的,因为如果前面都受限制,后面一位也一定受限制,可前面都是前导零,不代表后面也得是前导零
          况且,[0,i-1]位全是前导零的情况,顶多出现一次,后面再怎么递归都不可能有的
         */
        if(!isLimit && memo[i][mask]!=-1){
            return memo[i][mask];
        }

        int res = 0;
        if(!isNum){
            // 含前导零,跳过
            res = f(i+1,mask,false,false);
        }

        // 当前位置是否受限制
        int upper = isLimit?s[i]-'0':9;
        // 枚举可以填充的数,这里要检查是否之前使用了前导零,使用了的话要去掉(之前if(!isNum)加过了),没有用的话可以从0开始
        for (int j = isNum?0:1; j <= upper; j++) {
            // 特殊整数要求各数位都不一样,所以检查mask中之前用过没
            // mask就是个位图,比如 binary(mask) = 0101010111 从右往左看,用集合论表示就是 {8,6,4,2,1,0} 这些数字已经被用过了
            if((mask>>j&1)==0){
                // 用掉j就是把mask的第j位设置为0 mask|(1<<j)
                // 那么下一个数位是否受限制呢?如果当前数位受限制并且当前数位选取了上限,那么下一位是受限制的,否则不受限制
                // 例如,n = 123 i=1,假设当前数位不受上限,那么也就是说这个数<=99,下一位当然也不受限制
                // 但若受限制,则代表前面i=0的时候选择了上限1,i=1选择最多是2,下一位自然受限制,最大是3
                // 最后,既然这一位枚举了一位有意义的数字,后续的枚举自然也是有意义的, isNum 为 true
                res += f(i+1,mask|(1<<j),isLimit && j==upper,true);
            }
        }

        // 记忆化
        if(!isLimit){
            memo[i][mask] = res;
        }

        return res;
    }
}

2. LC 233 数字1的个数

这题是我学了板子后做的第一题,真的汗流浃背。想想做做调调,卡了1h才出来。

首先这个跟上面一题的区别是,对于枚举的数字没有限制。不仅没有重复性的限制,而且还没有前导零的限制(前导零的数不会被判无效,因为这道题只看1的数量)

这样就简便了不少。我们只需要看是否被上限限制即可。这个是否受限还是和以前一样,只有前面的都受限,本次才会受限,否则不会。

令f(i,isLimit)表示第[i,n-1]位在Limit的限制下能产生的1的数量。那么本轮的上限可以由isLimit计算得出:

  1. 如果upper<1,也就是upper=0的情况,本轮不可能选1

  2. 如果upper==1,本轮选择1会产生 int(suffix(n,i+1))+1 个1。其中suffix(n,i+1)代表n的[i+1,n-1]位的值,例如 n = 2132 , i=1 那么suffix(2232,2) = 32。

    这是比较显然的,拿上面那个例子来说,如果本轮选择1,后续会有2100到2132这些数的第i=1位是1,所以就是32-00+1=33个

  3. 如果upper>1,本轮选择1会产生 pow(10,n-i-1)个1。例如n=2232,i=1,那么如果本轮选择1,就会有2100到2199的第i=1位是1,也就是pow(4-1-1)=100个

以上考虑的是本轮(第i位)产生的1的数量。后面[i+1,n-1]产生的还没算:

两种情况讨论。上限是upper,说明本轮有upper种选择,其中可能有一种是顶格选的(isLimit情况),有upper种是非顶格选的。依次累加到res即可。这里对于前者,根据当前的isLimit来(如果当前顶格了后面也得顶格,当前不顶格后面也不顶格),根据后者,isLimit = false

最后,记搜的时候要记得排除顶格选的情况。因为这种情况已经被统计过了。

import java.util.Arrays;

class Solution {
    char[] s;
    int[] memo;
    public int countDigitOne(int n) {
        s = Integer.toString(n).toCharArray();
        memo = new int[s.length];
        Arrays.fill(memo,-1);
        return f(0,true);
    }

    /**
     * 记搜计算[0,i-1]位选择完毕后,后面的位置总共能出现多少个1
     * @param i 第i位
     * @param isLimit 受到最大上限限制与否
     * @return [0,i-1]位选择完毕后,后面的位置总共能出现多少个1
     */
    private int f(int i,boolean isLimit){
        if(i==s.length){
            return 0;
        }

        /*
        例如:n = 1230 ,现在 i=[0,1,2] = {1,2,3},那么后面一个1都不可能有
        但如果i=[0,1,2] = {1,2,2},后面是可以有一个1的
        memo记录的是后者
         */
        if(!isLimit && memo[i]!=-1){
            return memo[i];
        }

        int res = 0;
        int upper = isLimit?s[i]-'0':9;
        // 如果这一轮选1
        if(upper==1){
            res += suffix(i+1)+1;
        }else if(upper>1){
            res += (int) Math.pow(10,s.length-i-1);
        }
        // 本来可以选 upper+1个数(这一轮)
        // 如果之前全部都顶格选了,那么将是upper个可以后续不用顶格选的,和一个必须顶格选的
        res += upper*f(i+1,false) + f(i+1,isLimit);

        if(!isLimit){
            memo[i] = res;
        }

        return res;
    }

    private int suffix(int start){
        StringBuilder sb = new StringBuilder();
        for(int i=start;i<s.length;i++){
            sb.append(s[i]);
        }
        return sb.isEmpty()?0:Integer.parseInt(sb.toString());
    }
}

3. LC 2719 统计整数数目

这道题我思路有的,但就是有点歪,所以虽然A了但是时间上表现不好

首先我的记搜是包含4个状态的:定义 f (i,isLower,isUpper,acc)表示在[0,i-1]位均已枚举,且数位和为acc,且是(否)受下限与上限的制约的情况下,后续能够产生的符合条件的数。

那么上限和下限分别怎么算?我通过补齐较小的num1的前导零,使其与num2在数位长度上等长。这样下限由num1(补齐前导零后)决定,上限由num2决定。

在深搜时记忆化在不受上下限制约的情况下,在枚举到第i位且已有数位累计和acc的情况下,后续能有多少个符合条件的数。

之后根据是否受上下限制约枚举数位即可。这里注意枚举时可以及时地判断是否已经爆掉数位和上界了,而下界可以留到最终递归基的是否判断。

最后,我现在是觉得,模运算这个东西,有很强的性质(加法乘法的性质都特别强),如果担心答案爆了怎么办,就在能取模的地方全部取模就行。

import java.util.Arrays;

class Solution {
    static long mod = (long)1e9+7;

    char[] s1;
    char[] s2;

    long[][] memo;

    int min;
    int max;
    public int count(String num1, String num2, int min_sum, int max_sum) {
        min = min_sum;
        max = max_sum;
        s1 = supplyLeadingZero(num1,num2).toCharArray();
        s2 = num2.toCharArray();

        // memo[i][acc]代表在不受限制的情况下 到了第i位已经有acc的数位和,第[i+1,n-1]位最多能有多少个符合条件的数
        memo = new long[s2.length][22*9+1];
        for (int i = 0; i < memo.length; i++) {
            Arrays.fill(memo[i],-1L);
        }

        return (int) (f(0,true,true,0) % mod);
    }

    private String supplyLeadingZero(String num1,String num2){
        StringBuilder num1Builder = new StringBuilder(num1);
        while(num1Builder.length()<num2.length()){
            num1Builder.insert(0, "0");
        }
        num1 = num1Builder.toString();
        return num1;
    }

    private long f(int i,boolean isLower,boolean isUpper,int acc){
        if(i==s2.length){
            return acc>=min?1L:0L;
        }

        if(!isLower && !isUpper && memo[i][acc]!=-1){
            return memo[i][acc];
        }

        int lb = isLower?s1[i]-'0':0;
        int ub = isUpper?s2[i]-'0':9;

        long res = 0L;
        for(int j=lb;j<=ub;j++){
            if(max>=acc+j){
                res = (res%mod + f(i+1,isLower && j==lb, isUpper && j==ub, acc+j)%mod) % mod;
            }
        }

        if(!isLower && !isUpper){
            memo[i][acc] = res;
        }

        return res;
    }
}

还有一种更常见的思路是,先统计一遍≤num1的情况,再统计一遍≤num2的情况,然后后者减前者就是(num1,num2]的情况。又因为题目是闭区间,所以单独判一下num1符合条件与否即可。这种思路跑得比我的代码快,这里摘录一份:

class Solution {
    private static final int MOD = 1_000_000_007;

    public int count(String num1, String num2, int minSum, int maxSum) {
        int ans = calc(num2, minSum, maxSum) - calc(num1, minSum, maxSum) + MOD; // 避免负数
        int sum = 0;
        for (char c : num1.toCharArray()) {
            sum += c - '0';
        }
        if (minSum <= sum && sum <= maxSum) {
            ans++; // num1 是合法的,补回来
        }
        return ans % MOD;
    }

    private int calc(String s, int minSum, int maxSum) {
        int n = s.length();
        int[][] memo = new int[n][Math.min(9 * n, maxSum) + 1];
        for (int[] row : memo) {
            Arrays.fill(row, -1);
        }
        return dfs(0, 0, true, s.toCharArray(), minSum, maxSum, memo);
    }

    private int dfs(int i, int sum, boolean isLimit, char[] s, int minSum, int maxSum, int[][] memo) {
        if (sum > maxSum) { // 非法
            return 0;
        }
        if (i == s.length) {
            return sum >= minSum ? 1 : 0;
        }
        if (!isLimit && memo[i][sum] != -1) {
            return memo[i][sum];
        }

        int up = isLimit ? s[i] - '0' : 9;
        int res = 0;
        for (int d = 0; d <= up; d++) { // 枚举当前数位填 d
            res = (res + dfs(i + 1, sum + d, isLimit && (d == up), s, minSum, maxSum, memo)) % MOD;
        }

        if (!isLimit) {
            memo[i][sum] = res;
        }
        return res;
    }
}

4. LC 600 不含连续1的非负整数

实际上我们可以发现,除了数位和n相同的数字,其他数字只要确定了长度,除了第一个二进制数位必须是1(我们可以单独另算0),其他的都可以任意选择。所以分两种情况:

  1. 正在计算的是数位长度<n的数位长度的数字
  2. 正在计算的是数位长度=n的数位长度的数字

对于第一种情况,维护前一个数位上的0-1标记即可。

对于第二种情况,有两种状态:

  1. 前面的数位顶满了。例如最大110b,现在已经11了,后面那个就只能选0而不能选1。
  2. 前面的数位没顶满,这里就可以随便选,只要不和前面连续的1即可。

py记搜,cpp改了个递推的dp。

from functools import cache

class Solution:
    def findIntegers(self, n: int) -> int:
        tmp = n
        bit_len = 0
        bits = [-1]
        while tmp:
            bits.append(tmp&1)
            tmp >>= 1
            bit_len += 1

        @cache
        def dfs(prev:int,rest:int,full:int)->int:
            if not rest:
                return 1
            
            res = 0
            # 前面顶满了
            if full:
                for i in range(bits[rest]+1):
                    if prev and prev==i:
                        continue
                    res += dfs(i,rest-1,(1 if i==bits[rest] else 0))
            # 前面没顶满
            else:
                for i in range(2):
                    if prev and prev==i:
                        continue
                    res += dfs(i,rest-1,0)
            
            return res
        
        ans = 1
        for i in range(1,bit_len+1):
            if i==bit_len:
                ans += dfs(1,bit_len-1,1)
            else:
                ans += dfs(1,i-1,0)
        return ans
        
#include <vector> 

using namespace std;

class Solution {
public:
    int findIntegers(int n) {
        int tmp = n;
        vector<int> bits{-1};
        int bit_len = 0;
        while(tmp){
            bits.push_back(tmp&1);
            tmp >>= 1;
            bit_len += 1;
        }

        vector<vector<vector<int>>> dp(bit_len,vector<vector<int>>(2,vector<int>(2)));

        int ans = 1;
        dp[0][1][0] = 1;
        dp[0][1][1] = 1;
        dp[0][0][1] = 1;
        dp[0][0][0] = 1;

        int i;
        for(i=0;i<bit_len;i++){
            if(i==0){
                ans += dp[i][1][0];
            }else{
                dp[i][1][1] = dp[i-1][0][bits[i]==0];
                dp[i][0][1] = dp[i-1][0][bits[i]==0] + (bits[i]?dp[i-1][1][1]:0);
                dp[i][1][0] = dp[i-1][0][0];
                dp[i][0][0] = dp[i-1][1][0] + dp[i-1][0][0];
                if(i<bit_len-1){
                    ans += dp[i][1][0];
                }else{
                    ans += dp[i][1][1];
                }
            }
        }

        return ans;
    }
};

5. LC 3129/3130 找出所有稳定的二进制数组Ⅱ

这题我一开始的思路就是力大砖飞。直接状压了前limit个元素,然后开了个长为limit的全1掩码判断是否只能增加0或者1(位与如果和mask相等说明全1,只能增加0;如果为0说明全0,只能增加1),还是都能走。然后再加个记搜写了个深搜。

from functools import cache
import math

mod = 10**9+7

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        n = zero+one

        if limit>n:
            limit = n
        
        mask = 0
        for _ in range(limit):
            mask <<= 1
            mask |= 1
            
        
        @cache
        def dfs(i:int,num:int,x:int)->int:
            if x>zero or i-1-x>one:
                return 0
            if i>n:
                return 1

            res = 0
            if (mask&num)==mask:
                num <<= 1
                num &= mask
                res = (res + dfs(i+1,num,x+1)%mod)%mod
            elif (mask&num)==0:
                num <<= 1
                num |= 1
                num &= mask
                res = (res + dfs(i+1,num,x)%mod)%mod
            else:
                res = (res + (dfs(i+1,(num<<1)&mask,x+1)%mod + dfs(i+1,((num<<1)|1)&mask,x)%mod)%mod)%mod
            
            return res

        ans = 0
        for i in range(2**limit):
            ans = (ans + dfs(limit+1,i,limit-bin(i).count('1'))%mod)%mod
            # print(ans)
        return ans

这个有个比较明显的问题,因为limit∈[1,200],所以状压就是2^200次方,天文数字,直接MLE。

然后实在不会,看了灵神的题解和视频:

题解:. - 力扣(LeetCode) 视频:动态规划【力扣双周赛 129】_哔哩哔哩_bilibili

首先limit的含义是:至多有连续limit个0或者1。

定义dp(i,j,k)为在还剩i个0,j个1的情况下在第i+j的数位上放置数字k,其中k∈{0,1}。那么:

dp(i,j,0) = dp(i-1,j,0) + dp(i-1,j,1)

注意,k是当前放置的数字,所以dp(i-1,j,0)代表的是之前在第i+j位置上放了个0,然后现在在第i+j-1个位置上再放置一个0。

但是这样是不对的。因为可能会连续防止超过limit个0或者1。因此我们要想办法把这些不合法的情况减掉。

如果当前放置了0,那么再往前limit个数如果全是0,则0的个数还剩下i-limit-1个。此时再往前必须放1了。不然就爆掉了。

因此:

dp(i,j,0) = dp(i-1,j,0) + dp(i-1,j,1) - dp(i-limit-1,j,1)

其实到这里我有个疑问,为什么不再减掉:

dp(i-limit-1,j,0)

呢?

先说结论,因为在计算dp(i-1,j,0)时减掉过了:

dp(i-1,j,0) = dp(i-2,j,0) + dp(i-2,j,1) - dp(i-1-limit-1,j,1)

dp(i-limit-2,j,1)

本质上就是我们想减掉的dp(i-limit-1,j,0)中的一部分,当然还有dp(i-limit-3,j,1)、dp(i-limit-4,j,1)这些,共同组成了dp(i-limit-1,j,0)。

来看一个比较形象的例子,假设limit=2。那么:

  1. 100(0)是合法的。这个数累计到了dp(i-3,j,1)里面,但是再加上现在0,就不合法了。
  2. 1000(0)本身就是不合法的,这种情况在dp(i-1,j,0)里面就已经筛掉了。自然包含在dp(i-limit-1,j,0)里面。
from functools import cache

mod = 10**9+7

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        
        @cache
        def dfs(i:int,j:int,k:int)->int:
            if i==0:
                return 1 if j<=limit and k==1 else 0
            if j==0:
                return 1 if i<=limit and k==0 else 0
            if k==0:
                return (dfs(i-1,j,0) + dfs(i-1,j,1) - (dfs(i-limit-1,j,1) if i>limit else 0) )%mod
            else:
                return (dfs(i,j-1,0) + dfs(i,j-1,1) - (dfs(i,j-limit-1,0) if j>limit else 0) )%mod
        
        ans = (dfs(zero,one,0)+dfs(zero,one,1))%mod
        dfs.cache_clear()
        return ans

#include <vector>

using namespace std;

const int MOD = 1e9+7;

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<int>>> 
            dp(zero+1,vector<vector<int>>(one+1,vector<int>(2)));

        int i;
        for(i=0;i<=min(limit,one);i++){
            dp[0][i][1] = 1;
        }
        for(i=0;i<=min(limit,zero);i++){
            dp[i][0][0] = 1;
        }

        int j;
        for(i=1;i<=zero;i++){
            for(j=1;j<=one;j++){
                dp[i][j][0] = ((long long)dp[i-1][j][0] + dp[i-1][j][1] - (i>limit?dp[i-limit-1][j][1]:0) + MOD)%MOD;
                dp[i][j][1] = ((long long)dp[i][j-1][0] + dp[i][j-1][1] - (j>limit?dp[i][j-limit-1][0]:0) + MOD)%MOD;
            }
        }
        return (dp[zero][one][0]+dp[zero][one][1])%MOD;
    }
};

6. LC 3007 价值和小于等于K的最大数组

首先1-n中,这个n越大,累计和越大。所以是单调的,可以二分。

在计算1-mid中的累计和时,可以借鉴:

233. 数字 1 的个数 - 力扣(LeetCode)

这题,套数位dp板子,枚举选或不选,只是要加上数位索引i%x==0的条件,以满足设置位为1的要求。

from functools import cache

class Solution:
    def findMaximumNumber(self, k: int, x: int) -> int:
        def sum(mid:int)->int:
            @cache
            def dfs(i:int,cnt:int,is_Limit:bool)->int:
                if i==0:
                    return cnt
                
                up = (mid>>(i-1))&1 if is_Limit else 1
                res = 0
                for d in range(up+1):
                    res += dfs(i-1,cnt+(d==1 and i%x==0),is_Limit and d==up)
                return res
            return dfs(mid.bit_length(),0,True)
        
        l,r,ans = 1,(k+1)*256,1
        while l<r:
            mid = (l+r)>>1
            if sum(mid)>k:
                r = mid
            else:
                ans = mid
                l = mid+1
        
        return ans

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值