数位DP的大致思想:枚举每一位能选取的合法值。
1. LC 2376 统计特殊整数
说是DP,但实际上状态转移方程挺难写的,毕竟是枚举+集合论,这里就不贴状态转移方程了。总体的写法其实是搜索+记忆化。之所以称之为DP,是因为:
- 对于第i位,
- 如果[0,i-1]位全部都选的能选的最大值,那么第i位最多也就只能选到最大值(注意第0位肯定是受限制的)
- 否则,第i位能随便选
- 如果前面[0,i-1]位都是0(前导零),这一位就可以随便选了
有这两大状态转移的规则,所以还是被称之为DP。细节写在代码里了
import java.util.Arrays;
class Solution {
char[] s;
int[][] memo;
public int countSpecialNumbers(int n) {
s = Integer.toString(n).toCharArray();
// 记搜
// memo[i][mask]代表[0,i-1]位选掉了mask中各个索引位置代表的数的情况下后面有多少个特殊整数
memo = new int[s.length][1<<10];
for (int i = 0; i < memo.length; i++) {
Arrays.fill(memo[i],-1); // -1 代表没有计算过
}
// 一开始第一位当然要受限制,不过也可以是前导零,所以 (true,false)
return f(0,0,true,false);
}
/**
* 记搜计算在确定[0,i-1]位后剩下的有种特殊整数的情况
* @param i 第i位
* @param mask 已经选择了mask集合中的数
* @param isLimit 当前位上限是否有限制,有的话是 int(s[i]),没有的话是9
* @param isNum 当前位是否跳过,即前导零
* @return 确定[0,i-1]位后剩下的有种特殊整数的情况
*/
private int f(int i,int mask,boolean isLimit,boolean isNum){
if(i==s.length){
return isNum ?1:0; // 防止从头到尾的前导零,这种情况根本不是个数字,当然不能算
}
/*
这个!isLimit条件非常重要,举个例子,n=420
假设现在前两位选了 4,2,那么mask = {4,2},第三位就只能选0
但如果前两位选了2,4,mask={4,2},第三位可以随便选(除了2,4)
在第一种情况下能选1种,第二种情况下能选8种(10-2),差了7种情况
而之后的循环枚举会考虑所有顶着最大上限选的情况,所以如果当前这个数位受到最大上限的限制的话,后面的for循环会统计这个情况的
而不是把非受限制的情况的记忆化搜索结果返回,在这个例子里dp[2][{2,4}] = 8 而不是1
但isNum不是必须判断的,因为如果前面都受限制,后面一位也一定受限制,可前面都是前导零,不代表后面也得是前导零
况且,[0,i-1]位全是前导零的情况,顶多出现一次,后面再怎么递归都不可能有的
*/
if(!isLimit && memo[i][mask]!=-1){
return memo[i][mask];
}
int res = 0;
if(!isNum){
// 含前导零,跳过
res = f(i+1,mask,false,false);
}
// 当前位置是否受限制
int upper = isLimit?s[i]-'0':9;
// 枚举可以填充的数,这里要检查是否之前使用了前导零,使用了的话要去掉(之前if(!isNum)加过了),没有用的话可以从0开始
for (int j = isNum?0:1; j <= upper; j++) {
// 特殊整数要求各数位都不一样,所以检查mask中之前用过没
// mask就是个位图,比如 binary(mask) = 0101010111 从右往左看,用集合论表示就是 {8,6,4,2,1,0} 这些数字已经被用过了
if((mask>>j&1)==0){
// 用掉j就是把mask的第j位设置为0 mask|(1<<j)
// 那么下一个数位是否受限制呢?如果当前数位受限制并且当前数位选取了上限,那么下一位是受限制的,否则不受限制
// 例如,n = 123 i=1,假设当前数位不受上限,那么也就是说这个数<=99,下一位当然也不受限制
// 但若受限制,则代表前面i=0的时候选择了上限1,i=1选择最多是2,下一位自然受限制,最大是3
// 最后,既然这一位枚举了一位有意义的数字,后续的枚举自然也是有意义的, isNum 为 true
res += f(i+1,mask|(1<<j),isLimit && j==upper,true);
}
}
// 记忆化
if(!isLimit){
memo[i][mask] = res;
}
return res;
}
}
2. LC 233 数字1的个数
这题是我学了板子后做的第一题,真的汗流浃背。想想做做调调,卡了1h才出来。
首先这个跟上面一题的区别是,对于枚举的数字没有限制。不仅没有重复性的限制,而且还没有前导零的限制(前导零的数不会被判无效,因为这道题只看1的数量)
这样就简便了不少。我们只需要看是否被上限限制即可。这个是否受限还是和以前一样,只有前面的都受限,本次才会受限,否则不会。
令f(i,isLimit)表示第[i,n-1]位在Limit的限制下能产生的1的数量。那么本轮的上限可以由isLimit计算得出:
-
如果upper<1,也就是upper=0的情况,本轮不可能选1
-
如果upper==1,本轮选择1会产生 int(suffix(n,i+1))+1 个1。其中suffix(n,i+1)代表n的[i+1,n-1]位的值,例如 n = 2132 , i=1 那么suffix(2232,2) = 32。
这是比较显然的,拿上面那个例子来说,如果本轮选择1,后续会有2100到2132这些数的第i=1位是1,所以就是32-00+1=33个
-
如果upper>1,本轮选择1会产生 pow(10,n-i-1)个1。例如n=2232,i=1,那么如果本轮选择1,就会有2100到2199的第i=1位是1,也就是pow(4-1-1)=100个
以上考虑的是本轮(第i位)产生的1的数量。后面[i+1,n-1]产生的还没算:
两种情况讨论。上限是upper,说明本轮有upper种选择,其中可能有一种是顶格选的(isLimit情况),有upper种是非顶格选的。依次累加到res即可。这里对于前者,根据当前的isLimit来(如果当前顶格了后面也得顶格,当前不顶格后面也不顶格),根据后者,isLimit = false
最后,记搜的时候要记得排除顶格选的情况。因为这种情况已经被统计过了。
import java.util.Arrays;
class Solution {
char[] s;
int[] memo;
public int countDigitOne(int n) {
s = Integer.toString(n).toCharArray();
memo = new int[s.length];
Arrays.fill(memo,-1);
return f(0,true);
}
/**
* 记搜计算[0,i-1]位选择完毕后,后面的位置总共能出现多少个1
* @param i 第i位
* @param isLimit 受到最大上限限制与否
* @return [0,i-1]位选择完毕后,后面的位置总共能出现多少个1
*/
private int f(int i,boolean isLimit){
if(i==s.length){
return 0;
}
/*
例如:n = 1230 ,现在 i=[0,1,2] = {1,2,3},那么后面一个1都不可能有
但如果i=[0,1,2] = {1,2,2},后面是可以有一个1的
memo记录的是后者
*/
if(!isLimit && memo[i]!=-1){
return memo[i];
}
int res = 0;
int upper = isLimit?s[i]-'0':9;
// 如果这一轮选1
if(upper==1){
res += suffix(i+1)+1;
}else if(upper>1){
res += (int) Math.pow(10,s.length-i-1);
}
// 本来可以选 upper+1个数(这一轮)
// 如果之前全部都顶格选了,那么将是upper个可以后续不用顶格选的,和一个必须顶格选的
res += upper*f(i+1,false) + f(i+1,isLimit);
if(!isLimit){
memo[i] = res;
}
return res;
}
private int suffix(int start){
StringBuilder sb = new StringBuilder();
for(int i=start;i<s.length;i++){
sb.append(s[i]);
}
return sb.isEmpty()?0:Integer.parseInt(sb.toString());
}
}
3. LC 2719 统计整数数目
这道题我思路有的,但就是有点歪,所以虽然A了但是时间上表现不好
首先我的记搜是包含4个状态的:定义 f (i,isLower,isUpper,acc)表示在[0,i-1]位均已枚举,且数位和为acc,且是(否)受下限与上限的制约的情况下,后续能够产生的符合条件的数。
那么上限和下限分别怎么算?我通过补齐较小的num1的前导零,使其与num2在数位长度上等长。这样下限由num1(补齐前导零后)决定,上限由num2决定。
在深搜时记忆化在不受上下限制约的情况下,在枚举到第i位且已有数位累计和acc的情况下,后续能有多少个符合条件的数。
之后根据是否受上下限制约枚举数位即可。这里注意枚举时可以及时地判断是否已经爆掉数位和上界了,而下界可以留到最终递归基的是否判断。
最后,我现在是觉得,模运算这个东西,有很强的性质(加法乘法的性质都特别强),如果担心答案爆了怎么办,就在能取模的地方全部取模就行。
import java.util.Arrays;
class Solution {
static long mod = (long)1e9+7;
char[] s1;
char[] s2;
long[][] memo;
int min;
int max;
public int count(String num1, String num2, int min_sum, int max_sum) {
min = min_sum;
max = max_sum;
s1 = supplyLeadingZero(num1,num2).toCharArray();
s2 = num2.toCharArray();
// memo[i][acc]代表在不受限制的情况下 到了第i位已经有acc的数位和,第[i+1,n-1]位最多能有多少个符合条件的数
memo = new long[s2.length][22*9+1];
for (int i = 0; i < memo.length; i++) {
Arrays.fill(memo[i],-1L);
}
return (int) (f(0,true,true,0) % mod);
}
private String supplyLeadingZero(String num1,String num2){
StringBuilder num1Builder = new StringBuilder(num1);
while(num1Builder.length()<num2.length()){
num1Builder.insert(0, "0");
}
num1 = num1Builder.toString();
return num1;
}
private long f(int i,boolean isLower,boolean isUpper,int acc){
if(i==s2.length){
return acc>=min?1L:0L;
}
if(!isLower && !isUpper && memo[i][acc]!=-1){
return memo[i][acc];
}
int lb = isLower?s1[i]-'0':0;
int ub = isUpper?s2[i]-'0':9;
long res = 0L;
for(int j=lb;j<=ub;j++){
if(max>=acc+j){
res = (res%mod + f(i+1,isLower && j==lb, isUpper && j==ub, acc+j)%mod) % mod;
}
}
if(!isLower && !isUpper){
memo[i][acc] = res;
}
return res;
}
}
还有一种更常见的思路是,先统计一遍≤num1的情况,再统计一遍≤num2的情况,然后后者减前者就是(num1,num2]的情况。又因为题目是闭区间,所以单独判一下num1符合条件与否即可。这种思路跑得比我的代码快,这里摘录一份:
class Solution {
private static final int MOD = 1_000_000_007;
public int count(String num1, String num2, int minSum, int maxSum) {
int ans = calc(num2, minSum, maxSum) - calc(num1, minSum, maxSum) + MOD; // 避免负数
int sum = 0;
for (char c : num1.toCharArray()) {
sum += c - '0';
}
if (minSum <= sum && sum <= maxSum) {
ans++; // num1 是合法的,补回来
}
return ans % MOD;
}
private int calc(String s, int minSum, int maxSum) {
int n = s.length();
int[][] memo = new int[n][Math.min(9 * n, maxSum) + 1];
for (int[] row : memo) {
Arrays.fill(row, -1);
}
return dfs(0, 0, true, s.toCharArray(), minSum, maxSum, memo);
}
private int dfs(int i, int sum, boolean isLimit, char[] s, int minSum, int maxSum, int[][] memo) {
if (sum > maxSum) { // 非法
return 0;
}
if (i == s.length) {
return sum >= minSum ? 1 : 0;
}
if (!isLimit && memo[i][sum] != -1) {
return memo[i][sum];
}
int up = isLimit ? s[i] - '0' : 9;
int res = 0;
for (int d = 0; d <= up; d++) { // 枚举当前数位填 d
res = (res + dfs(i + 1, sum + d, isLimit && (d == up), s, minSum, maxSum, memo)) % MOD;
}
if (!isLimit) {
memo[i][sum] = res;
}
return res;
}
}
4. LC 600 不含连续1的非负整数
实际上我们可以发现,除了数位和n相同的数字,其他数字只要确定了长度,除了第一个二进制数位必须是1(我们可以单独另算0),其他的都可以任意选择。所以分两种情况:
- 正在计算的是数位长度<n的数位长度的数字
- 正在计算的是数位长度=n的数位长度的数字
对于第一种情况,维护前一个数位上的0-1标记即可。
对于第二种情况,有两种状态:
- 前面的数位顶满了。例如最大110b,现在已经11了,后面那个就只能选0而不能选1。
- 前面的数位没顶满,这里就可以随便选,只要不和前面连续的1即可。
py记搜,cpp改了个递推的dp。
from functools import cache
class Solution:
def findIntegers(self, n: int) -> int:
tmp = n
bit_len = 0
bits = [-1]
while tmp:
bits.append(tmp&1)
tmp >>= 1
bit_len += 1
@cache
def dfs(prev:int,rest:int,full:int)->int:
if not rest:
return 1
res = 0
# 前面顶满了
if full:
for i in range(bits[rest]+1):
if prev and prev==i:
continue
res += dfs(i,rest-1,(1 if i==bits[rest] else 0))
# 前面没顶满
else:
for i in range(2):
if prev and prev==i:
continue
res += dfs(i,rest-1,0)
return res
ans = 1
for i in range(1,bit_len+1):
if i==bit_len:
ans += dfs(1,bit_len-1,1)
else:
ans += dfs(1,i-1,0)
return ans
#include <vector>
using namespace std;
class Solution {
public:
int findIntegers(int n) {
int tmp = n;
vector<int> bits{-1};
int bit_len = 0;
while(tmp){
bits.push_back(tmp&1);
tmp >>= 1;
bit_len += 1;
}
vector<vector<vector<int>>> dp(bit_len,vector<vector<int>>(2,vector<int>(2)));
int ans = 1;
dp[0][1][0] = 1;
dp[0][1][1] = 1;
dp[0][0][1] = 1;
dp[0][0][0] = 1;
int i;
for(i=0;i<bit_len;i++){
if(i==0){
ans += dp[i][1][0];
}else{
dp[i][1][1] = dp[i-1][0][bits[i]==0];
dp[i][0][1] = dp[i-1][0][bits[i]==0] + (bits[i]?dp[i-1][1][1]:0);
dp[i][1][0] = dp[i-1][0][0];
dp[i][0][0] = dp[i-1][1][0] + dp[i-1][0][0];
if(i<bit_len-1){
ans += dp[i][1][0];
}else{
ans += dp[i][1][1];
}
}
}
return ans;
}
};
5. LC 3129/3130 找出所有稳定的二进制数组Ⅱ
这题我一开始的思路就是力大砖飞。直接状压了前limit个元素,然后开了个长为limit的全1掩码判断是否只能增加0或者1(位与如果和mask相等说明全1,只能增加0;如果为0说明全0,只能增加1),还是都能走。然后再加个记搜写了个深搜。
from functools import cache
import math
mod = 10**9+7
class Solution:
def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
n = zero+one
if limit>n:
limit = n
mask = 0
for _ in range(limit):
mask <<= 1
mask |= 1
@cache
def dfs(i:int,num:int,x:int)->int:
if x>zero or i-1-x>one:
return 0
if i>n:
return 1
res = 0
if (mask&num)==mask:
num <<= 1
num &= mask
res = (res + dfs(i+1,num,x+1)%mod)%mod
elif (mask&num)==0:
num <<= 1
num |= 1
num &= mask
res = (res + dfs(i+1,num,x)%mod)%mod
else:
res = (res + (dfs(i+1,(num<<1)&mask,x+1)%mod + dfs(i+1,((num<<1)|1)&mask,x)%mod)%mod)%mod
return res
ans = 0
for i in range(2**limit):
ans = (ans + dfs(limit+1,i,limit-bin(i).count('1'))%mod)%mod
# print(ans)
return ans
这个有个比较明显的问题,因为limit∈[1,200],所以状压就是2^200次方,天文数字,直接MLE。
然后实在不会,看了灵神的题解和视频:
首先limit的含义是:至多有连续limit个0或者1。
定义dp(i,j,k)为在还剩i个0,j个1的情况下在第i+j的数位上放置数字k,其中k∈{0,1}。那么:
注意,k是当前放置的数字,所以dp(i-1,j,0)代表的是之前在第i+j位置上放了个0,然后现在在第i+j-1个位置上再放置一个0。
但是这样是不对的。因为可能会连续防止超过limit个0或者1。因此我们要想办法把这些不合法的情况减掉。
如果当前放置了0,那么再往前limit个数如果全是0,则0的个数还剩下i-limit-1个。此时再往前必须放1了。不然就爆掉了。
因此:
其实到这里我有个疑问,为什么不再减掉:
呢?
先说结论,因为在计算dp(i-1,j,0)时减掉过了:
而
本质上就是我们想减掉的dp(i-limit-1,j,0)中的一部分,当然还有dp(i-limit-3,j,1)、dp(i-limit-4,j,1)这些,共同组成了dp(i-limit-1,j,0)。
来看一个比较形象的例子,假设limit=2。那么:
- 100(0)是合法的。这个数累计到了dp(i-3,j,1)里面,但是再加上现在0,就不合法了。
- 1000(0)本身就是不合法的,这种情况在dp(i-1,j,0)里面就已经筛掉了。自然包含在dp(i-limit-1,j,0)里面。
from functools import cache
mod = 10**9+7
class Solution:
def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
@cache
def dfs(i:int,j:int,k:int)->int:
if i==0:
return 1 if j<=limit and k==1 else 0
if j==0:
return 1 if i<=limit and k==0 else 0
if k==0:
return (dfs(i-1,j,0) + dfs(i-1,j,1) - (dfs(i-limit-1,j,1) if i>limit else 0) )%mod
else:
return (dfs(i,j-1,0) + dfs(i,j-1,1) - (dfs(i,j-limit-1,0) if j>limit else 0) )%mod
ans = (dfs(zero,one,0)+dfs(zero,one,1))%mod
dfs.cache_clear()
return ans
#include <vector>
using namespace std;
const int MOD = 1e9+7;
class Solution {
public:
int numberOfStableArrays(int zero, int one, int limit) {
vector<vector<vector<int>>>
dp(zero+1,vector<vector<int>>(one+1,vector<int>(2)));
int i;
for(i=0;i<=min(limit,one);i++){
dp[0][i][1] = 1;
}
for(i=0;i<=min(limit,zero);i++){
dp[i][0][0] = 1;
}
int j;
for(i=1;i<=zero;i++){
for(j=1;j<=one;j++){
dp[i][j][0] = ((long long)dp[i-1][j][0] + dp[i-1][j][1] - (i>limit?dp[i-limit-1][j][1]:0) + MOD)%MOD;
dp[i][j][1] = ((long long)dp[i][j-1][0] + dp[i][j-1][1] - (j>limit?dp[i][j-limit-1][0]:0) + MOD)%MOD;
}
}
return (dp[zero][one][0]+dp[zero][one][1])%MOD;
}
};
6. LC 3007 价值和小于等于K的最大数组
首先1-n中,这个n越大,累计和越大。所以是单调的,可以二分。
在计算1-mid中的累计和时,可以借鉴:
这题,套数位dp板子,枚举选或不选,只是要加上数位索引i%x==0的条件,以满足设置位为1的要求。
from functools import cache
class Solution:
def findMaximumNumber(self, k: int, x: int) -> int:
def sum(mid:int)->int:
@cache
def dfs(i:int,cnt:int,is_Limit:bool)->int:
if i==0:
return cnt
up = (mid>>(i-1))&1 if is_Limit else 1
res = 0
for d in range(up+1):
res += dfs(i-1,cnt+(d==1 and i%x==0),is_Limit and d==up)
return res
return dfs(mid.bit_length(),0,True)
l,r,ans = 1,(k+1)*256,1
while l<r:
mid = (l+r)>>1
if sum(mid)>k:
r = mid
else:
ans = mid
l = mid+1
return ans