和HDU 2089 不要62 初探数位dp这道题思路类似,先将dp数组预处理,再利用dp数组的结果计算[0, n)区间符合条件的数字的个数。
定义状态dp[i][j][k][l]表示以j开头的i位数字中模13余l的数字个数,k = 0表示这些数字不包含13, k = 0表示包含13。
然后按照相同的思路计算[0, n)内满足条件的数字个数,只不过要维护一个状态变量flag,flag = 0表示前面不包含13,flag = 1表示前面包含13。具体方法不再赘述,参见博客开头的链接。
另外要注意在计算过程中会爆int,所以作者写了一个快速幂取模函数mod_pow。
代码如下:
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cmath>
using namespace std;
typedef long long int ll;
int dp[15][15][2][15];
int digit[15];
int mod_pow(int a, int n)
{
int res = 1;
while (n)
{
if (n & 1)
res = (res * a) % 13;
a = (a * a) % 13;
n >>= 1;
}
return res;
}
void init()
{
memset(dp, 0, sizeof(dp));
for (int i = 0; i <= 9; i++)
dp[1][i][0][i] = 1;
for (int i = 2; i <= 10; i++)
for (int j = 0; j <= 9; j++)
for (int flag = 0; flag <= 1; flag++)
for (int l = 0; l <= 12; l++)
for (int k = 0; k <= 9; k++)
{
if (!flag)
{
if (!(j == 1 && k == 3))
dp[i][j][flag][l] += dp[i - 1][k][flag][(l + 13 - (j * mod_pow(10, i - 1)) % 13) % 13];
}
else
{
dp[i][j][flag][l] += dp[i - 1][k][flag][(l + 13 - (j * mod_pow(10, i - 1)) % 13) % 13];
if (j == 1 && k == 3)
dp[i][j][flag][l] += dp[i - 1][k][!flag][(l + 13 - (j * mod_pow(10, i - 1)) % 13) % 13];
}
}
}
// 计算[0, n)中有多少个包含13,同时又是13倍数的数字
int Count(int n)
{
memset(digit, 0, sizeof(digit));
int len = 0;
while (n)
{
digit[++len] = n % 10;
n /= 10;
}
int ans = 0, flag = 0, m = 0;
for (int i = len; i >= 1; i--)
{
for (int j = 0; j < digit[i]; j++)
{
ans += dp[i][j][1][(13 - m) % 13];
if (flag == 1)
ans += dp[i][j][0][(13 - m) % 13];
else if (digit[i + 1] == 1 && j == 3)
ans += dp[i][j][0][(13 - m) % 13];
}
m = (m + digit[i] * mod_pow(10, i - 1)) % 13;
if (digit[i + 1] == 1 && digit[i] == 3)
flag = 1;
}
return ans;
}
int main()
{
//freopen("test.txt", "r", stdin);
int n;
init();
while (~scanf("%d", &n))
{
printf("%d\n", Count(n + 1));
}
return 0;
}
dfs的写法确实比递推要简洁。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
using namespace std;
int dp[15][15][3];
int digit[10];
int dfs(int pos, int mod, int flag, int limit)
{
if (!pos)
return !mod && flag == 2;
if (!limit && dp[pos][mod][flag] != -1)
return dp[pos][mod][flag];
int up = limit ? digit[pos] : 9;
int ans = 0;
for (int i = 0; i <= up; i++)
{
int temp = (mod * 10 + i) % 13;
if (flag == 2)
ans += dfs(pos - 1, temp, flag, limit && i == up);
else if (flag == 1)
{
if (i == 3)
ans += dfs(pos - 1, temp, 2, limit && i == up);
else
ans += dfs(pos - 1, temp, i == 1, limit && i == up);
}
else if (flag == 0)
ans += dfs(pos - 1, temp, i == 1, limit && i == up);
}
return limit ? ans : dp[pos][mod][flag] = ans;
}
int cal(int n)
{
int len = 0;
while (n)
{
digit[++len] = n % 10;
n /= 10;
}
return dfs(len, 0, 0, 1);
}
int main()
{
int n;
while (~scanf("%d", &n))
{
memset(dp, -1, sizeof(dp));
printf("%d\n", cal(n));
}
return 0;
}