题目大意
有好多串字符串 任选两个(可以相同) 再分别从两个字符串中任意截取一个前缀组合成一个新的字符串
问:有多少种不同的组合?
若是两串完全不同的串,则方案数为 (len[A]+len[B])2 ( l e n [ A ] + l e n [ B ] ) 2 表示两个串的长度之和的平方
例如:ab 和 c
方案数为 (2+1)*(2+1)
方案为
aa
aba
abab
aab
cc
ac
abc
ca
cab
代表的就是每一个前缀和每一个前缀(可相同)组合
但是怕的就是 有相同的部分
例如: abc 与 bcd 组合
组合出 abcd 的方案共两种 abc-d、a-bcd
那么这个组合就应该减去一种
换句话说 abc 的后缀 与 bcd 的前缀 有一个公共部分 bc-bc ,在这个bc-bc中任选一个部分都可以组成原组合
那么我们就需要统计这些公共部分的出现次数
那么我们就需要一个叫做AC自动机的东西
在自动机中 匹配失配后的
fail
f
a
i
l
到的下一个点即是与以当前字符串的部分后缀为前缀的字符串
比如 abcd----fail---->bcd
在AC自动机中 bcd一定是一个单独的字符串的前缀部分
而 bcd 恰好又是 abcd 的后缀(abcd不一定是一个完整的字符串,它只是一部分前缀,这也正好符合了题意)
于是乎,我们就可以沿着
fail
f
a
i
l
往上跳,并且统计
每一个被
fail
f
a
i
l
指向的字符串都是当前前缀的后缀
而
fail
f
a
i
l
的
fail
f
a
i
l
也相应地是当前前缀的后缀,只是更短而已
所以我们可以统计一个
sum
s
u
m
,表示以当前字符串为后缀的前缀个数
转移时就可以通过
fail
f
a
i
l
来继承(详见代码)
而对于每一个前缀
S
S
我们通过它的 不断向上跳获得它的最短后缀
然后我们就能够其可移动的部分
求解
对于每一个前缀 我们都
fail
f
a
i
l
一次 即:获得它的后半部分
然后我们移动到它的前半部分(除去后半部分)
例如 对于 bc 、 abcbc
当我们讨论abcbc时
fail到的后半部分应该是 bc
于是 当前的前半部分就是 abc
然后我们就开始讨论以当前前缀的前半部分为的字符串前缀的后半部分的数量
例如 上述例子中的abcbc的前半部分为abc
那么我们要寻找的就是如同 cabc tabc babc ababc 这样以abc为后半部分的前缀(并非完整字符串)
而这个数量就是当前字符串的
sum
s
u
m
值(见定义)
统计到
delta
d
e
l
t
a
中即可
于是结果为 cnt2−delta c n t 2 − d e l t a (cnt为AC自动机点数之和)
打完之后才发现代码中每一次都是讨论的都是前缀,但其实原理是一样的,都是把和自己有公共部分的删掉,只留一个最短的,
代码+注释(本人的代码 可能有错)
#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
int fa[300123],len[300123],deg[300123];
int cnt,ch[300123][26],fail[300123];
long long sum[300123];
char s[300123];
queue<int>q;
struct ACM
{
void init()
{
memset(fa,0,sizeof(fa));
memset(sum,0,sizeof(sum));
memset(deg,0,sizeof(deg));
memset(fail,0,sizeof(fail));
memset(ch,0,sizeof(ch));
cnt=0;
}
//建立Trie
void insert(char *s)
{
int l=strlen(s),c;
for(int i=0;i<l;i++)
{
c=s[i]-'a';
if(ch[p][c])p=ch[p][c];
else
{
fa[++cnt]=p;
len[cnt]=len[p]+1;
p=ch[p][c]=c;
}
}
}
//建立AC自动机
void Build()
{
for(int i=0;i<26;i++)if(ch[0][i])q.push(ch[0][i]);
int x;
while(!q.empty())
{
x=q.front();q.pop();
for(int i=0;i<26;i++)
if(ch[x][i])
{
fail[ch[x][i]]=ch[fail[x]][i];
q.push(ch[x][i]);
}
else ch[x][i]=ch[fail[x]][i];
}
for(int i=1;i<=cnt;i++)sum[i]=1;
for(int i=1;i<=cnt;i++)
if(fail[i])deg[fail[i]]++;//统计连接到每一个点的fail边数量
//用类似TOP排序的方法计算sum
for(int i=1;i<=cnt;i++)
if(!deg[i])q.push(i);
while(!q.empty())
{
x=q.front();
sum[fail[x]]+=sum[x];
//只有当所有连向某个点的fail边都计算过了才能对那个点进行下一步计算
if(!--deg[fail[x]])q.push(fail[x]);
}
}
void solve()
{
long long delta=0;
for(int i=1;i<=cnt;i++)
{
int temp=len[fail[i]],p=i;//找到第一个后缀
while(temp--)p=fa[p];//反向找到前缀
delta+=(fail[i]!=0)*(sum[p]-1);//sum中只保留一个
}
printf("%lld\n",1LL*cnt*cnt-delta);
}
}AC;
int main()
{
while(1)
{
scanf("%d",n);
while(n--)
{
scanf("%s",s);
AC.insert(s);
}
AC.Build();
AC.solve();
}
return 0;
}