题目描述
题解
首先把所有的S串都丢到trie树里,建立fail树。
每加进来一个T,把它在AC自动机上暴力匹配,匹配到的每一个点在fail树中到根的路径上出现过的S串end标记的S串都应该+1,也就是说,每一次求出匹配到的每一个点在fail树中到根的路径上出现过的end标记表示的S串,然后取并集,这些S串的答案应该+1.
由于有可能有重复计算,我们需要把所有匹配过的点按照dfs序排序,然后要消除相邻的两个点的lca到根的路径上多加的点的影响。如果是树链修改的话比较麻烦,可以转化为单点修改然后查询子树区间权值和,这样就可以用bit轻松实现了。
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
#define S 2000005
#define N 100005
#define lg 21
char s[S];
int n,m,opt,x,sz,dfs_clock;
int ch[S][30],fail[S],is_end[N],f[S][lg+1],in[S],out[S],h[S],pt[S],C[S];
int tot,point[S],nxt[S],v[S];
queue <int> q;
void insert(int id)
{
int len=strlen(s),now=0;
for (int i=0;i<len;++i)
{
int x=s[i]-'a';
if (!ch[now][x]) ch[now][x]=++sz;
now=ch[now][x];
}
is_end[id]=now;
}
void make_fail()
{
while (!q.empty()) q.pop();
for (int i=0;i<26;++i)
if (ch[0][i]) q.push(ch[0][i]);
while (!q.empty())
{
int now=q.front();q.pop();
for (int i=0;i<26;++i)
{
if (!ch[now][i])
{
ch[now][i]=ch[fail[now]][i];
continue;
}
fail[ch[now][i]]=ch[fail[now]][i];
q.push(ch[now][i]);
}
}
}
void add(int x,int y)
{
// printf("%d %d\n",x,y);
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
}
void build(int x,int dep)
{
h[x]=dep;in[x]=++dfs_clock;
for (int i=1;i<lg;++i)
f[x][i]=f[f[x][i-1]][i-1];
for (int i=point[x];i;i=nxt[i])
{
f[v[i]][0]=x;
build(v[i],dep+1);
}
out[x]=dfs_clock;
}
int lca(int x,int y)
{
if (h[x]<h[y]) swap(x,y);
int k=h[x]-h[y];
for (int i=0;i<lg;++i)
if ((k>>i)&1) x=f[x][i];
if (x==y) return x;
for (int i=lg-1;i>=0;--i)
if (f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
void ac()
{
int len=strlen(s),now=0;
for (int i=0;i<len;++i)
{
int x=s[i]-'a';
int y=ch[now][x];
pt[++pt[0]]=y;
now=y;
}
}
int cmp(int a,int b)
{
return in[a]<in[b];
}
void change(int loc,int val)
{
// printf("%d %d\n",loc,val);
for (int i=loc;i<=sz+1;i+=i&(-i))
C[i]+=val;
}
int query(int loc)
{
int ans=0;
for (int i=loc;i>=1;i-=i&(-i))
ans+=C[i];
return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;++i)
{
scanf("%s",s);
insert(i);
}
make_fail();
for (int i=1;i<=sz;++i) add(fail[i],i);
build(0,1);
scanf("%d",&m);
for (int i=1;i<=m;++i)
{
scanf("%d",&opt);
if (opt==1)
{
scanf("%s",s);
pt[0]=0;ac();
if (!pt[0]) continue;
sort(pt+1,pt+pt[0]+1,cmp);
change(in[pt[1]],1);
for (int i=2;i<=pt[0];++i)
{
int r=lca(pt[i-1],pt[i]);
change(in[r],-1);
change(in[pt[i]],1);
}
}
else
{
scanf("%d",&x);x=is_end[x];
int ans=query(out[x])-query(in[x]-1);
printf("%d\n",ans);
}
}
}