题目:给你N个字典单词s,Q个前缀s1,后缀s2,保证s1和s2不想交,问通过s1,s2能在字典中和多少个单词匹配。
思路:假设单词为abcd,前缀为ab,后缀为cd,那么我们将单词变成abcd#abcd,前后缀变为cd#ab,那么我们就可以通过AC自动机解决这道题目了
代码:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<string>
#include<vector>
#include<map>
#include<bitset>
#include<set>
#include<queue>
#include<stack>
#include<list>
#include<numeric>
using namespace std;
#define LL long long
#define ULL unsigned long long
#define INF 0x3f3f3f3f
#define mm(a,b) memset(a,b,sizeof(a))
#define PP puts("*********************");
template<class T> T f_abs(T a){ return a > 0 ? a : -a; }
template<class T> T gcd(T a, T b){ return b ? gcd(b, a%b) : a; }
template<class T> T lcm(T a,T b){return a/gcd(a,b)*b;}
// 0x3f3f3f3f3f3f3f3f
// 0x3f3f3f3f
const int SIGMA_SIZE = 27;
const int MAXNODE = 1e6+50;
const int maxn=1e5+50;
int ch[MAXNODE][SIGMA_SIZE];
int f[MAXNODE];
int val[MAXNODE];
int last[MAXNODE];
int cnt[MAXNODE],dep[MAXNODE];
int pos[maxn],len[maxn];
int sz;
void AC_init(){
sz=1;
mm(ch[0],0);
mm(cnt,0);
dep[0]=0;
}
void AC_insert(string s,int v){
int u=0,n=s.size();
for(int i=0;i<n;i++){
int c=s[i]-'a';
if(!ch[u][c]){
mm(ch[sz],0);
val[sz]=0;
dep[sz]=dep[u]+1;
ch[u][c]=sz++;
}
u=ch[u][c];
}
val[u]=1;
pos[v]=u;
}
int AC_find(string T,int L){
int n=T.size();
int j=0;
for(int i=0;i<n;i++){
int c=T[i]-'a';
j=ch[j][c];
while(dep[j]>L)
j=f[j];
cnt[last[j]]++;
}
}
int num[MAXNODE],stk[MAXNODE];
void AC_count(){
int top=0;
mm(num,0);
for(int i=0;i<sz;i++) num[f[i]]++;
for(int i=0;i<sz;i++) if(!num[i]) stk[top++]=i;
for(int i=0;i<top;i++){
int j=f[stk[i]];
cnt[j]+=cnt[stk[i]];
if((--num[j])==0)
stk[top++]=j;
}
}
void AC_getFail(){
queue<int>q;
f[0]=0;
for(int c=0;c<SIGMA_SIZE;c++){
int u=ch[0][c];
if(u){
f[u]=0;q.push(u);last[u]=0;
}
}
while(!q.empty()){
int r=q.front();q.pop();
for(int c=0;c<SIGMA_SIZE;c++){
int u=ch[r][c];
if(!u){
ch[r][c]=ch[f[r]][c];continue;
}
q.push(u);
int v=f[r];
while(v&&!ch[v][c]) v=f[v];
f[u]=ch[v][c];
last[u]=val[u]?u:last[f[u]];
}
}
}
string s[maxn];
string s1,s2;
int main(){
int T,N,Q;
scanf("%d",&T);
while(T--){
AC_init();
scanf("%d%d",&N,&Q);
for(int i=1;i<=N;i++){
cin>>s[i];
len[i]=s[i].size()+1;//这是防止重复覆盖要用的
string temp=s[i];
s[i]+='z'+1;
s[i]+=temp;
}
for(int i=1;i<=Q;i++){
cin>>s1>>s2;
s2+='z'+1;
s2+=s1;
AC_insert(s2,i);
}
AC_getFail();
for(int i=1;i<=N;i++)
AC_find(s[i],len[i]);
AC_count();
for(int i=1;i<=Q;i++)
printf("%d\n",cnt[pos[i]]);
}
return 0;
}