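// This one really tripped me up: why do the missing edges have to be filled in?
// Because match() follows state[u].next[i] directly. A filled-in edge already points to the state
// that walking the fail links would reach, so a missing transition never wrongly falls back to the root.
// The nice trick in this problem is that the fail links reach every suffix of the current match, so the
// cnt flag propagated along them marks every state whose text ends with some pattern.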
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <vector>
#define LL long long
#define INF 0x3f3f3f3f
#define mod 1000000007
const int maxn = 10000+5;   // capacity of the trie (number of automaton states)
using namespace std;

char p[1000];               // buffer for one pattern read from input
int n,k,l;                  // number of characters, number of patterns, target string length
double dp[maxn][105];       // dp[u][len]: memoized value of Aho::match(u, len)
bool vis[maxn][105];        // vis[u][len]: true once dp[u][len] has been computed
double pos[1000];           // pos[c]: probability of the character whose alphabet index is c

// Aho-Corasick automaton over the 62-symbol alphabet [a-z][A-Z][0-9].
struct Aho{
    static const int SIGMA = 62;   // alphabet size used by build() and match()

    struct node{
        int next[63];   // goto transitions; after build() every entry in [0, SIGMA) is valid
        int fail,cnt;   // fail link; cnt != 0 if some pattern is a suffix of this state's text
    }state[maxn];
    queue<int> q;
    int size;           // number of states in use; state 0 is the root

    // Map a character to its alphabet index: 'a'-'z' -> 0..25, 'A'-'Z' -> 26..51,
    // anything else is treated as a digit -> ch - '0' + 52 (input assumed alphanumeric).
    int idx(char ch) {
        // Cast first: passing a negative char to islower/isupper is undefined behavior.
        const int c = static_cast<unsigned char>(ch);
        if(islower(c)) return c - 'a';
        if(isupper(c)) return c - 'A' + 26;
        return c - '0' + 52;
    }

    // Reset every state and leave only the root.
    void init(){
        while(!q.empty()) q.pop();
        for(int i=0; i<maxn; i++){
            memset(state[i].next, 0, sizeof(state[i].next));
            state[i].fail = state[i].cnt = 0;
        }
        size = 1;
    }

    // Add the NUL-terminated pattern s to the trie and mark its terminal state.
    void insert(const char *s){
        int now = 0;
        for(const char *it = s; *it; ++it){
            const int c = idx(*it);
            if(!state[now].next[c])
                state[now].next[c] = size++;
            now = state[now].next[c];
        }
        state[now].cnt = 1;
    }

    // Compute fail links in BFS order and complete the goto function.
    // Missing edges are filled in so that match() can follow next[] directly:
    // next[u][i] becomes the state reached by following u's fail links until
    // one has an i-edge (or the root if none does).
    void build(){
        state[0].fail = -1;
        q.push(0);   // 0 is the root
        while(!q.empty()){
            const int u = q.front();
            q.pop();
            for(int i=0; i<SIGMA; i++){
                int &child = state[u].next[i];
                // fail(u) is shallower than u, so its transitions are already complete.
                const int via_fail = (u == 0) ? 0 : state[state[u].fail].next[i];
                if(child){
                    state[child].fail = via_fail;
                    // A pattern that ends at the fail state also ends at this state.
                    state[child].cnt |= state[via_fail].cnt;
                    q.push(child);
                }
                else{
                    child = via_fail;
                }
            }
        }
    }

    // Probability that a random string of length len, fed to the automaton from
    // state u with per-character probabilities pos[], never enters a state with
    // cnt set, i.e. contains no pattern. Results are memoized in dp/vis.
    double match(int u, int len){
        if(!len) return 1.0;
        if(vis[u][len]) return dp[u][len];
        vis[u][len] = true;
        double &ans = dp[u][len];
        ans = 0.0;
        for(int i=0; i<SIGMA; i++){
            if(pos[i] == 0.0) continue;   // never generated: contributes nothing
            const int v = state[u].next[i];
            if(state[v].cnt == 0) ans += pos[i] * match(v, len-1);
        }
        return ans;
    }
}aho;
// For each of T test cases: read k patterns, build the automaton, read n lines of
// "<char> <probability>", then read l and print the probability that a random
// string of length l contains none of the patterns.
int main(){
    int T;
    if(scanf("%d",&T) != 1) return 0;
    for(int kases = 1; kases <= T; kases++){
        aho.init();
        // dp needs no reset: match() overwrites dp[u][len] when it first sets vis[u][len].
        memset(vis, false, sizeof(vis));
        memset(pos, 0, sizeof(pos));
        if(scanf("%d",&k) != 1) return 0;
        for(int i=0; i<k; i++){
            // Width limit keeps a long pattern from overflowing p[1000].
            if(scanf("%999s",p) != 1) return 0;
            aho.insert(p);
        }
        aho.build();
        if(scanf("%d",&n) != 1) return 0;
        for(int i=0; i<n; i++){
            char ch;
            double ps;
            // The leading space skips all whitespace (including '\r' from CRLF input),
            // which a single getchar() before "%c" did not.
            if(scanf(" %c %lf",&ch,&ps) != 2) return 0;
            const int c = aho.idx(ch);
            if(c >= 0 && c < 62)   // ignore characters outside the 62-symbol alphabet
                pos[c] = ps;
        }
        if(scanf("%d",&l) != 1) return 0;
        printf("Case #%d: %.6lf\n",kases,aho.match(0,l));
    }
    return 0;
}
/*
Sample input:
2
1
a
2
a 0.5
b 0.5
2
2
ab
ab
2
a 0.2
b 0.8
2

Expected output:
Case #1: 0.250000
Case #2: 0.840000
*/