ZOJ 3494 BCD Code(AC自动机+数位DP)
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3494
题意:给你一个数据区间[A,B],问你该区间内有多少个数字转化成BCD码之后不含被禁止的01模板串.
分析:首先本题数位DP详见论文<<初探数位DP>>
本题首先用所有模板建立AC自动机,然后求出数组bcd[i][num]=j表示当前在自动机的i节点,然后走过数字(0-9)的BCD 4位二进制码之后到达的状态是j.
然后用bcd数组来做数位DP递推求出sum(a)表示区间[0,a]内所有的要求数据,然后求出区间[0,b]内的合法数据,两者相减.
AC代码:
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
typedef long long LL;
const int maxnode=2000+100;
const int sigma_size=2;
const int MOD=1000000009;
/**other_code**/
int bcd[2005][10]; //bcd[i][j]表示在结点i,经过一个数字j,到达的自动机结点编号
LL dp[205][2005]; //dp[i][j]表示长度为i,位于结点j的个数(应该是长度为i的合法串)
//dp[len][i]表示长为len的串(可能为全0,即允许前导0)且在自动机的i号节点,这样的串共有多少个
int bit[205],len,n;//bit[i]表示大数的第i位是多少(0-9,十进制)
//数位DP,长度为len,当前状态为pos,是否有限制,是否有前导0
LL dfs(int len,int pos,bool limit,bool zero)//limit表示是否有上界,如果limit为true,该位置len能枚举的数就不是0-9而是0-5或0-3?
{
if(len==0) return 1;
if(!limit&&dp[len][pos]!=-1) return dp[len][pos];
LL ans=0;
//如果之前全为0,但是由于0是不能计算的,所以当前不为最低位
if(len>1&&zero)
{
ans+=dfs(len-1,pos,limit&&bit[len]==0,true);
if(ans>=MOD) ans-=MOD;
}
else
{
//判断转移是否合法
if(bcd[pos][0]!=-1) ans+=dfs(len-1,bcd[pos][0],limit&&bit[len]==0,false);
if(ans>=MOD) ans-=MOD;
}
int up=limit?bit[len]:9;
for(int i=1;i<=up;i++)
{
if(bcd[pos][i]!=-1)
{
ans+=dfs(len-1,bcd[pos][i],limit&&i==up,false);
if(ans>=MOD) ans-=MOD;
}
}
if(!limit&&!zero) dp[len][pos]=ans;
return ans;
}
LL cal(char *s,int l)//将大数"99900"s逆序放到bit中"00999"
{
memset(dp,-1,sizeof(dp));
for(int i=1;i<=l;i++) bit[l-i+1]=s[i-1]-'0';
return dfs(l,0,true,true);
}
char A[205],B[205];
//高精度-1,这样会遗留前导0,无所谓了。。。
void sub(char *s,int len)//大数s-1之后的结果
{
for(int i=len-1;i>=0;i--)
{
if(s[i]=='0') s[i]='9';
else
{
s[i]--;
break;
}
}
}
/**other_code**/
struct AC_Automata
{
int ch[maxnode][sigma_size];
int f[maxnode];
int match[maxnode];
int sz;
void init()
{
sz=1;
memset(ch[0],0,sizeof(ch[0]));
f[0]=match[0]=0;
}
void insert(char *s)
{
int n=strlen(s),u=0;
for(int i=0;i<n;i++)
{
int id=s[i]-'0';
if(ch[u][id]==0)
{
ch[u][id]=sz;
memset(ch[sz],0,sizeof(ch[sz]));
match[sz++]=0;
}
u=ch[u][id];
}
match[u]=1;
}
void getFail()
{
queue<int> q;
for(int i=0;i<sigma_size;i++)
{
int u=ch[0][i];
if(u)
{
f[u]=0;
q.push(u);
}
}
while(!q.empty())
{
int r=q.front();q.pop();
for(int i=0;i<sigma_size;i++)
{
int u=ch[r][i];
if(!u) { ch[r][i]=ch[f[r]][i]; continue; }
q.push(u);
int v=f[r];
while(v && ch[v][i]==0) v=f[v];
f[u]=ch[v][i];
match[u] |= match[f[u]];
}
}
}
int BCD(int st,int num)
{
if(match[st]==1) return -1;
int u=st;
for(int i=3;i>=0;i--)
{
int id=(num>>i)&1;
u=ch[u][id];
if(match[u]==1) return -1;
}
return u;
}
void get_bcd()
{
for(int i=0;i<sz;i++)
for(int num=0;num<10;num++)
bcd[i][num]=BCD(i,num);
}
}ac;
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
ac.init();
int n;
scanf("%d",&n);
while(n--)
{
char str[100];
scanf("%s",str);
ac.insert(str);
}
ac.getFail();
ac.get_bcd();
scanf("%s",A);
sub(A,strlen(A));
LL ans = -cal(A,strlen(A));
scanf("%s",B);
ans += cal(B,strlen(B));
printf("%lld\n",(ans%MOD+MOD)%MOD);
}
}