题目:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=4317
题意:给定N个01串,再给定区间[a,b],问区间[a,b]里面有多少个数转化成BCD码之后不包含任何前面给出01串。
分析:首先将01串建ac自动机,然后把不可到达的点标记出来。用二维数组Matrix[][]把状态转移图(比如Matirx[cur][x]表示当前在ac机上cur位置,再加上一个数字x,跑到Matrix[cur][x]这个位置,-1表示不能走)建好,其实也可以不用建,建了方便dp用。剩下的就是数位dp了,从高位到低位填数字,dp数组:dp[len][cur][lim][zero]表示当前状态方案数,len表示当前在len位上选数字,cur表示当前在ac机的cur位置,lim标记是否可以在0~9里面任意选,zero表示前一位是否有前导0。(如果前一位是前导0,这意味着之前在ac机上的位置还是在root)。
举个例子:求0~4210满足条件的个数。
首先选最高位(第四位),这一位选0,其实相当于没选。
这一位选1,那么剩下的3位可以任意选。
这一位选2,那么剩下的3位可以任意选。
这一位选3,那么剩下的3位可以任意选。
这一位选4,那么下一位(第三位)只能选0或1或2。
依次类推。。。。。。
状态数len*cur*lim*zero≈ 1 000 000。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long long LL;
const LL mod = 1000000009;
const int maxn = 2015;
const int kd = 2;
struct trie
{
int son[maxn][kd],fail[maxn],fbd[maxn];
int cnt,root;
int dp[203][maxn][2][2];
int Matrix[maxn][10];
char s[maxn];
inline int newnode()
{
fill(son[cnt],son[cnt]+kd,-1);
fbd[cnt]=0;
return cnt++;
}
void Init()
{
cnt=0;
root=newnode();
}
void Insert(char *str)
{
int i,index,now=root;
for(i=0;str[i];i++)
{
index=str[i]-'0';
if(son[now][index]==-1)
son[now][index]=newnode();
now=son[now][index];
}
fbd[now]=1; //字符串结点禁止走
}
void findfail() //建ac机,找禁止走的位置
{
queue <int > q;
int i,temp;
fail[root]=root;
for(i=0;i<kd;i++)
if(son[root][i]==-1)
son[root][i]=root;
else
{
fail[son[root][i]]=root;
q.push(son[root][i]);
}
while(!q.empty())
{
temp=q.front();
q.pop();
fbd[temp]|=fbd[fail[temp]]; //fail指针指向的位置禁止走,也禁止走
for(i=0;i<kd;i++)
if(son[temp][i]==-1)
son[temp][i]=son[fail[temp]][i];
else
{
fail[son[temp][i]]=son[fail[temp]][i];
q.push(son[temp][i]);
}
}
}
int Next(int cur,int x) //状态转移 cur+x --> new
{
if(fbd[cur])
return -1;
for(int i=3;i>=0;i--)
{
if(fbd[son[cur][(x>>i)&1]])
return -1;
cur=son[cur][(x>>i)&1];
}
return cur;
}
void Prepare() //建状态转移图
{
for(int i=0;i<cnt;i++)
for(int j=0;j<10;j++)
Matrix[i][j]=Next(i,j);
}
LL dfs(int len,int cur,int lim,int zero) //数位dp,记忆化搜索
{
if(len==-1)
return 1;
if(dp[len][cur][lim][zero]!=-1)
return dp[len][cur][lim][zero];
int ret=0,tl,tz,END=(lim?s[len]-'0':9);
for(int i=0;i<=END;i++)
{
tl=(lim&&(i==END));
tz=(zero&&!i);
if(tz)
{
ret+=dfs(len-1,root,tl,tz);
if(ret>=mod)
ret-=mod;
continue ;
}
if(Matrix[cur][i]==-1)
continue ;
ret+=dfs(len-1,Matrix[cur][i],tl,tz);
if(ret>=mod)
ret-=mod;
}
return dp[len][cur][lim][zero]=ret;
}
void Sone(char *str) //减1
{
int len=strlen(str);
for(int i=len-1;i>=0;i--)
{
if(str[i]>'0')
{
--str[i];
break;
}
str[i]='9';
}
}
void Reversal(char *str) //反转,去前导零
{
int len=strlen(str);
for(int i=0,j=len-1;i<=len/2-1;i++,j--)
swap(str[i],str[j]);
for(int j=len-1;j>=0 && (str[j]=='0');j--)
str[j]='\0';
}
LL solve(char *s1,char *s2) //求0~s1-1的解,0~s2的解,相减
{
LL ans=0,len;
Sone(s1);
Reversal(s1);
len=strlen(s1);
strcpy(s,s1);
memset(dp,-1,sizeof(dp));
ans-=dfs(len-1,root,1,1);
Reversal(s2);
len=strlen(s2);
strcpy(s,s2);
memset(dp,-1,sizeof(dp));
ans+=dfs(len-1,root,1,1);
return (ans+mod)%mod;
}
}ac;
char ss1[maxn],ss2[maxn];
int main()
{
int ncase,n,i,j;
scanf("%d",&ncase);
while(ncase--)
{
scanf("%d",&n);
ac.Init();
while(n--)
{
scanf("%s",ac.s);
ac.Insert(ac.s);
}
ac.findfail();
ac.Prepare();
scanf("%s%s",ss1,ss2);
printf("%lld\n",ac.solve(ss1,ss2));
}
return 0;
}