题目链接
不看代码:
树状数组+dfs序列ORdfs维护子树大小(这个子树大小指的是叶子节点个数),对于一个ac自动机来说,建fail后,其父节点状态为子节点状态的后缀。反过来说,当前结点的子树大小为拥有该后缀的字符串个数。
所以把所有串丢到ac自动机,计算每个子树大小。
对于aaabbb,我们跑ac自动机,之后计算每个状态其子树的贡献。具体细节可以想想。主要是这个去重,这个我没想出kmp来搞。
例如aaaaa和baaaa,如果能匹配就是aaaaa和baaaa。按着上面的方法跑了前4个贡献是a,aa,aaa,aaaa,这时候可以想想kmp的nxt指针,指向上一个匹配的位置,aaaa上一个位置是aaa,因为要删了a,aa,aaa,这个时候可以开始思考怎么用kmp删了。然后就放代码了叭,上面的想不出来再来看看代码。
====================================================
代码版本
#include<bits/stdc++.h>
using namespace std;
using namespace std;
const int N=1000006;
const int maxn=1000006;
int nxt[maxn];
void gnxt(char *a,int len)
{
int i,j;
nxt[0]=-1;i=0,j=-1;
while(i<len)
{
if(j==-1||a[i]==a[j]){
i++;j++;
nxt[i]=j;
}else j=nxt[j];
}
}
struct node
{
int ch[N][26];
int fail[N];
int tot;
int cnt[N];
void init(){
memset(ch[0],0,sizeof ch[0]);
tot=0;fail[0]=cnt[0]=0;tot++;
}
void ins(char *str)
{
int L=strlen(str);
int rt=0;
for(int i=0;i<L;i++){
int c=str[i]-'a';
if(!ch[rt][c]){
memset(ch[tot],0,sizeof ch[tot]);
ch[rt][c]=tot;fail[tot]=0;cnt[tot]=0;tot++;
}
rt=ch[rt][c];
}
cnt[rt]++;
}
void build()
{
queue<int> que;
for(int i=0;i<26;i++)if(ch[0][i])que.push(ch[0][i]),fail[ch[0][i]]=0;
while(!que.empty())
{
int u=que.front();que.pop();
for(int i=0;i<26;i++){
if(!ch[u][i]){
ch[u][i]=ch[fail[u]][i];
}else {
fail[ch[u][i]]=ch[fail[u]][i];
que.push(ch[u][i]);
}
}
}
}
}ac;
int sum[maxn];
int tim;
void add(int pos,int val){
while(pos<=tim){
sum[pos]+=val;
pos+=(pos&(-pos));
}
}
int all(int pos)
{
int ans=0;
while(pos>=1){
ans+=sum[pos];
pos-=(pos&(-pos));
}
return ans;
}
int in[maxn],out[maxn];
struct edge
{
int v,nxt;
}e[maxn<<1];
int cnt;
int head[maxn];
void add_edge(int u,int v)
{
e[cnt].v=v;e[cnt].nxt=head[u];
head[u]=cnt++;
}
void dfs(int u,int f)
{
in[u]=++tim;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v;
if(v==f)continue;
dfs(v,u);
}
out[u]=tim;
}
void dfs1(int u,int f)
{
if(ac.cnt[u])
add(in[u],ac.cnt[u]);
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v;if(v==f)continue;
dfs1(v,u);
}
}
char str[maxn];
int b[maxn];
int id[maxn];
int rt[maxn];
int fa[maxn];
char s[maxn];
int be[maxn];
int L[maxn];
int main()
{
int n,m;
scanf("%d",&n);
ac.init();
int mod=998244353;
int fi=0;
long long ans=0;
for(int i=1;i<=n;i++){
scanf("%s",str+fi);
int len=strlen(str+fi);
L[i]=len;
be[i]=fi;
ac.ins(str+fi);
fi+=len;
}
ac.build();
memset(head,-1,sizeof head);
for(int i=1;i<ac.tot;i++){
add_edge(ac.fail[i],i);
}
dfs(0,-1);
dfs1(0,-1);
for(int i=1;i<=n;i++){
int rt=0;
long long sz=0;
gnxt(str+be[i],L[i]);
for(int j=be[i];j<be[i]+L[i];j++){
rt=ac.ch[rt][str[j]-'a'];
sz++;
ans+=(1ll*sz*sz-1ll*nxt[j-be[i]+1]*nxt[j-be[i]+1])%mod*(all(out[rt])-all(in[rt]-1))%mod;
if(ans<0){
ans+=mod;
}
ans%=mod;
}
}
cout<<ans<<endl;
}