两道例题:[USACO12JAN]Video Game G - 洛谷 和 [JSOI2007]文本生成器 - 洛谷
先说ac自动机上dp的一般形式:
dp数组一般都是f[i][j]的形式,表示递推到文本串的前i个字符,目前在自动机的j号点上
转移形式是枚举当前遍历到前i个字符,当前在u号结点上,然后a~z的26种转移。具体看例题代码
【1】第一道
#include <bits/stdc++.h>
using namespace std;
#define FOR(i,a,b) for(int i=(a), (i##i)=(b); i<=(i##i); ++i)
const int N = 305;
string s;
int n, k, ch[N][3], fail[N], tot, en[N];
int f[1005][N];
void getmx(int&a,int b){a=max(a,b);}
void ins(const string&s){
int len=s.size(), u=0;
for(int i=0; i<len; i++){
int x=s[i]-'A';
if(!ch[u][x]) ch[u][x]=++tot;
u=ch[u][x];
}
en[u]++; //标记结束
}
void get_fail(){
queue<int> q;
for(int i=0; i<3; i++)
if(ch[0][i]) q.push(ch[0][i]);
while(q.size()){
int u=q.front(); q.pop();
for(int i=0; i<3; i++){
if(ch[u][i]) fail[ch[u][i]] = ch[fail[u]][i], q.push(ch[u][i]);
else ch[u][i]=ch[fail[u]][i];
}
en[u] += en[fail[u]]; //fail链上所有的en值累加到u上
}
}
inline void solve(){
cin>>n>>k;
FOR(i,1,n) cin>>s, ins(s); //建trie
memset(f,-1,sizeof(f)); f[0][0]=0;
get_fail(); //构建fail树
for(int i=0; i<k; i++)
for(int u=0; u<=tot; u++){
if(f[i][u]==-1) continue; //没有到过,不能对后面产生贡献
for(int c=0; c<3; c++){
int v=ch[u][c];
getmx(f[i+1][v], f[i][u]+en[v]);
}
}
int ans=0;
FOR(i,0,tot) getmx(ans, f[k][i]);
cout<<ans;
}
int main(){
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int T=1; //cin>>T;
while(T--) solve();
}
【2】第二道
#include <bits/stdc++.h>
using namespace std;
#define FOR(i,a,b) for(int i=(a), (i##i)=(b); i<=(i##i); ++i)
const int N = 1e4+5, mod = 1e4+7;
string s;
int n, m, ch[N][26], fail[N], tot, en[N];
int f[105][N]; //f[i][j]表示前i个字符,目前在自动机j号点时的不可读数量
//总的不可读数量是 sum(f[m][i]) (0<=i<=tot)
void add(int&a,int b){a=(a+b)%mod;}
void ins(const string&s){
int len=s.size(), u=0;
for(int i=0; i<len; i++){
int x=s[i]-'A';
if(!ch[u][x]) ch[u][x]=++tot;
u=ch[u][x];
}
en[u]++; //标记结束
}
void get_fail(){
queue<int> q;
for(int i=0; i<26; i++)
if(ch[0][i]) q.push(ch[0][i]);
while(q.size()){
int u=q.front(); q.pop();
for(int i=0; i<26; i++){
if(ch[u][i]) fail[ch[u][i]] = ch[fail[u]][i], q.push(ch[u][i]);
else ch[u][i]=ch[fail[u]][i];
}
en[u] += en[fail[u]]; //fail链上所有的en值累加到u上
}
}
inline void solve(){
cin>>n>>m;
FOR(i,1,n) cin>>s, ins(s);
get_fail();
f[0][0]=1; //init
FOR(i,0,m-1) FOR(u,0,tot){
FOR(c,0,25){ //26个转移路线
int v=ch[u][c];
if(!en[v]) add(f[i+1][v], f[i][u]);
}
}
int sum=0, ans=1;
FOR(i,1,m) ans=ans*26%mod; //慢速幂
FOR(u,0,tot) add(sum, f[m][u]);
cout<<(ans-sum+mod)%mod;
}
signed main(){
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int T=1; //cin>>T;
while(T--) solve();
}