@(ACM题目)[字符串, AC自动机,fail树]
Description
某人读论文,一篇论文是由许多单词组成。但他发现一个单词会在论文中出现很多次,现在想知道每个单词分别在论文中出现多少次。
Input
第一个一个整数N,表示有多少个单词,接下来N行每行一个单词。每个单词由小写字母组成,N<=200,单词长度不超过10^6
Output
输出N个整数,第i行的数字表示第i个单词在文章中出现了多少次。
Sample Input
3
a
aa
aaa
Sample Output
6
3
1
题目分析
为了叙述方便,将从根节点到一个结点u所代表的字符串称为
首先将所有单词加入AC自动机,并统计每个结点u的路径字符串在单词表的中的单词前缀中出现的次数
对Trie树中的每个结点u,将它与
在这棵树中,对于每对父子(par,son),spar是sson的后缀,spar在sson中出现。所以,对于一个结点u,要统计
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
namespace ACautomaton
{
const int maxn = 1e6 + 5;
const int maxm = 26;
int ch[maxn][maxm];//ch[i][c]代表结点i的c孩子;初始有一个根节点,代表空字符串
// int val[maxn];//val为正代表这是一个模式串单词结点
int fail[maxn];//suffix link,代表当前路径字符串的最大前缀
// int last[maxn];//output link, 上一个单词结点
int tot;//Trie树中结点总数
int sz;//单词总数
int pos[205];//pos[i]为单词i在trie树中的对应结点
int cnt[maxn];//统计该结点的路径字符串出现的次数
int to[maxn], nxt[maxn], head[maxn], nume;//fail树
void init()
{
sz = 0;
tot = 1;
// val[0] = 0;
memset(ch[0], 0, sizeof ch[0]);
nume = 0;//零条边
memset(head, 0xff, sizeof head);
}
void addEdge(int u, int v)
{
to[nume] = v;
nxt[nume] = head[u];
head[u] = nume ++;
}
//O(n),n为所有模式总长度
void add(char *P, int v)//插入模式串,值为v
{
int u = 0;//当前结点
int n = strlen(P);
for(int i = 0; i < n; ++i)
{
int c = P[i] - 'a';
if(!ch[u][c])//若当前结点无c孩子,则创造一个
{
memset(ch[tot], 0, sizeof ch[tot]);
// val[tot] = 0;//中间结点的值为零
ch[u][c] = tot++;
}
u = ch[u][c];//走向当前结点的c孩子
++cnt[u];
}
//现在走到了模式串的结尾结点
// val[u] += v;
pos[sz++] = u;
}
//O(tot)的
void getFail()//构造fail指针和last指针
//使用BFS,因为fail指针一定指向长度更短的字符串
{
queue<int> q;
fail[0] = 0;
//初始化队列
for(int c = 0; c < maxm; ++c)
{
int u = ch[0][c];
if(u)
{
fail[u] = 0;//第一层结点的fail都是根节点
// last[u] = 0;
addEdge(0, u);
q.push(u);//将第一层结点加入队列
}
}
//BFS
while(!q.empty())
{
int cur = q.front();
q.pop();
for(int c = 0; c < maxm; ++c)//为cur结点的c孩子添加fail指针
{
int u = ch[cur][c];
if(!u)//当前结点没有c孩子
{
ch[cur][c] = ch[fail[cur]][c];//沿fail往上找,因为fail指针指向的还是这个后缀
continue;
}
q.push(u);//c孩子入队
int v = fail[cur];
while(v && !ch[v][c]) v = fail[v];//若后缀结点无c孩子,就沿fail指针一直网上找
fail[u] = ch[v][c];//给c孩子添加fail指针
addEdge(fail[u], u);
//若c孩子的fail指针指向模式串结点,则c孩子的last指向fail指针位置即可,因为这就是最长的
//否则指向fail指针指向的结点的last即可
// if(val[fail[u]]) last[u] = fail[u];
// else last[u] = last[fail[u]];
}
}
}
void dfs(int cur)
{
for(int i = head[cur]; ~i; i = nxt[i])
{
dfs(to[i]);
cnt[cur] += cnt[to[i]];
// cout << "cur = " << cur << "; son = " << to[i] << endl;
}
}
};
const int maxn = 1e6 + 5;
char s[maxn];
int main()
{
int n;
cin >> n;
ACautomaton::init();
for(int i = 0; i < n; ++i)
{
scanf("%s", s);
ACautomaton::add(s, 1);
}
ACautomaton::getFail();
ACautomaton::dfs(0);
for(int i = 0; i < n; ++i)
{
using namespace ACautomaton;
printf("%d\n", cnt[pos[i]]);
}
return 0;
}