AC自动机

目录

一、前言

二、AC自动机的基础

三、AC自动机原理与实现

(一)引入

(二)所研究问题

(三)Trie Tree

1. 建立结构体

2. 构造字典树

(二)建立FAIL

(三)匹配过程

四、例题

(一) 题目链接

(二)讲解

五、习题练习


一、前言

说道AC,相信许多人都会想到题目正确的AC,但是AC自动机的AC和题目中的AC可谓是八竿子打不着。

AC自动机英文名为Aho-Corasick automaton,这个算法诞生于1975年,AC就是两位发明者的姓氏啦!

二、AC自动机的基础

想学习AC自动机之前,你必须了解这俩个东东:KMP算法,字典树Trie Tree。

你当然可以自行上网查询,本文就讲解AC自动机:KMP算法      字典树Trie Tree​​​​​​

三、AC自动机原理与实现

(一)引入

我先说KMP算法,他用于解决单个模式串与文本串的问题。而这里的AC自动机,则是研究多个模式串与文本串的问题,可以说AC自动机是KMP算法的进阶版。

(二)所研究问题

给定 n 个模式串 s_i 和一个文本串 t,求有多少个不同的模式串在文本串里出现过。两个模式串不同当且仅当他们编号不同。

(三)Trie Tree

对于多个模式串,并且还要查询(判断时),暴力的时间复杂度为O(n ),而Trie Tree的时间复杂度则会大大减少。

这也是为什么第一步要建立Trie Tree。

1. 建立结构体

在建立字典树之前,我们先定义每个字典树上节点的结构体变量。

基础代码:

struct Node
{
	int fail;
	int next[26];
	int count;
}trie[1000005];

指针代码:

struct Node
{
    node *fail;
    node *next[26];
    int count;
};

Node *root;

相信知道Trie Tree的小伙伴都发现了,里面多了一个fail,这个先放一放,等会再介绍。

2. 构造字典树

想必大家都知道了,这里废话就不多说了。

基础代码:

int tot = 0;

void build_trie(string str)
{
	int id = 0;
	int len = str.size();
	for(int i = 0;i < len;i++)
	{
		int now = str[i] - 'a';
		if(trie[id].next[now] == 0)
		{
			trie[id].next[now] = ++tot;
		}
		id = trie[id].next[now];
	}
	trie[id].count++;
}

指针代码:

Node *newnode;

void build_trie(char *s)
{
    Node *p = root;
    for(int i = 0;s[i];i++)
    {
        int x = s[i] - 'a';
        if(p -> next[x] == NULL)
        {
            newnode = (struct Node *)malloc(sizeof(struct Node));
            for(int j = 0;j < 26;j++)
			{
				newnode -> next[j] = 0;
			}
            newnode -> sum = 0;
			newnode -> fail = 0;
            p -> next[x] = newnode;
        }
        p = p -> next[x];
    }
    p -> sum++;
}

(二)建立FAIL

假如我们有5个单词she , he , her , shr , sn,并且有一个文本串fasherhr。

那么我们建立的字典树就会长成这样:(红色为单词结尾,用上面代码中count记录)

对于文本串,如果按照正常方法便会如此操作:

① f    ×      ② a    ×    ③ s(1)      ④ h(4)     ⑤ e(7)    ⑥ 返回h ⑦ h(2)    ......

但是时间复杂度就到了O(n\log{n}),并且和一般的Trie Tree就没什么区别了。

在kmp算法中我们构建了一个 next (对于一些编译器中的 bits/stdc++.h 头文件中拥有 next 语句,所以建议大家使用 nex 或 Next) 数组(查询表),通过查询表,在我们每次失配的情况下快速移动模式串,从而避免了大量的不必要的比较。

而在AC自动机也有这么一个东西叫做失配指针(就是前文建立的fail):(红色线就是失配指针)

 操作就会成为这样:

① f    ×        ② a     ×       ③ s(1)         ④ h(4) / h(2) ⑤ e(7) / e(5)     ⑥ r(8)    ......

时间复杂度马上就降到了O(n),由于每次向下求一层的失配指针,所以用bfs建立失配指针。

基础代码:

void build_fail()
{
	int id = 0;
	for(int i = 0;i < 26;i++)
	{
		int now = trie[id].next[i];
		if(now != 0)
		{
			q.push(now);
			trie[now].fail = id;
		}
	}
	while(!q.empty())
	{
		int f = q.front();
		q.pop();
		for(int i = 0;i < 26;i++)
	    {
	        int now = trie[f].next[i];
	        if(now == 0)
	        {
	        	trie[f].next[i] = trie[trie[f].fail].next[i];
	            continue;
	        }
	        trie[now].fail = trie[trie[f].fail].next[i];
	        q.push(now);
	    }
	}
}

指针代码:

Node *q[MAX];
Node *newnode;

int head,tail;

void build_fail()
{
    head = 0;
    tail = 1;
    q[head] = root;
    Node *p;
    Node *temp;
    while(head < tail)
    {
        temp = q[head++];
        for(int i = 0;i <= 25;i++)
        {
            if(temp -> next[i])
            {
                if(temp == root)
                {
                    temp -> next[i] -> fail = root;
                }
                else
                {
                    p = temp -> fail;
                    while(p)
                    {
                        if(p -> next[i])
                        {
                            temp -> next[i] -> fail = p -> next[i];
                            break;
                        }
                        p = p -> fail;
                    }
                    if(p == NULL)
					{
						temp -> next[i] -> fail = root;
					}
                }
                q[tail++] = temp -> next[i];
            }
        }
    }
}

(三)匹配过程

最后到了匹配,也就是求得答案。

基础代码:

int solve_AC(string art)
{
	int ans = 0;
	int id = 0;
	int len = art.size();
	for(int i = 0;i < len;i++)
	{
		int now = trie[id].next[art[i] - 'a'];
		while(now != 0 && trie[now].count != -1)
		{
			ans += trie[now].count;
			trie[now].count = -1;
			now = trie[now].fail; 
		}
		id = trie[id].next[art[i] - 'a'];
	}
	return ans;
}

指针代码:

int solve_AC(char *ch)
{
	int ans;
    Node *p = root;
    int len = strlen(ch);
    for(int i = 0;i < len;i++)
    {
        int x = ch[i] - 'a';
        while(!p -> next[x] && p != root)
		{
			p = p -> fail;
		}
        p = p -> next[x];
        if(!p)
		{
			p = root;
		}
        Node *temp = p;
        while(temp != root)
        {
			if(temp -> sum >= 0)
			{
				ans += temp -> sum;
				temp -> sum = -1;
			}
           	else
		   	{
		   		break;	
		   	}
            temp = temp -> fail;
        }
    }
    return ans;
}

四、例题

(一) 题目链接

AC自动机【简单版】       AC自动机【加强版】       AC自动机【二次加强版】

(二)讲解

对于简单版和刚才一样,代码如下:

(基础代码)

#include <bits/stdc++.h>
using namespace std;

struct Node
{
	int fail;
	int next[26];
	int count;
}trie[1000005];

int tot;
queue<int> q;

void build_trie(string str)
{
	int id = 0;
	int len = str.size();
	for(int i = 0;i < len;i++)
	{
		int now = str[i] - 'a';
		if(trie[id].next[now] == 0)
		{
			trie[id].next[now] = ++tot;
		}
		id = trie[id].next[now];
	}
	trie[id].count++;
}

void build_fail()
{
	int id = 0;
	for(int i = 0;i < 26;i++)
	{
		int now = trie[id].next[i];
		if(now != 0)
		{
			q.push(now);
			trie[now].fail = id;
		}
	}
	while(!q.empty())
	{
		int f = q.front();
		q.pop();
		for(int i = 0;i < 26;i++)
	    {
	        int now = trie[f].next[i];
	        if(now == 0)
	        {
	        	trie[f].next[i] = trie[trie[f].fail].next[i];
	            continue;
	        }
	        trie[now].fail = trie[trie[f].fail].next[i];
	        q.push(now);
	    }
	}
}

int solve_AC(string art)
{
	int ans = 0;
	int id = 0;
	int len = art.size();
	for(int i = 0;i < len;i++)
	{
		int now = trie[id].next[art[i] - 'a'];
		while(now != 0 && trie[now].count != -1)
		{
			ans += trie[now].count;
			trie[now].count = -1;
			now = trie[now].fail; 
		}
		id = trie[id].next[art[i] - 'a'];
	}
	return ans;
}

signed main()
{
	string str;
	int n;
	scanf("%d",&n);
	for(int i = 1;i <= n;i++)
	{
		cin >> str;
		build_trie(str);
	}
	build_fail();
	cin >> str;
	printf("%d\n",solve_AC(str));
	return 0;
}

(指针代码)

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e7 + 5;
const int MAX = 10000000;

struct Node
{
    Node *next[26];
    Node *fail;
    int sum;
};
Node *root;
Node *q[MAX];
Node *newnode;

char pattern[maxn];
char key[70];
int head,tail;
int n;

void build_trie(char *s)
{
    Node *p = root;
    for(int i = 0;s[i];i++)
    {
        int x = s[i] - 'a';
        if(p -> next[x] == NULL)
        {
            newnode = (struct Node *)malloc(sizeof(struct Node));
            for(int j = 0;j < 26;j++)
			{
				newnode -> next[j] = 0;
			}
            newnode -> sum = 0;
			newnode -> fail = 0;
            p -> next[x] = newnode;
        }
        p = p -> next[x];
    }
    p -> sum++;
}

void build_fail()
{
    head = 0;
    tail = 1;
    q[head] = root;
    Node *p;
    Node *temp;
    while(head < tail)
    {
        temp = q[head++];
        for(int i = 0;i <= 25;i++)
        {
            if(temp -> next[i])
            {
                if(temp == root)
                {
                    temp -> next[i] -> fail = root;
                }
                else
                {
                    p = temp -> fail;
                    while(p)
                    {
                        if(p -> next[i])
                        {
                            temp -> next[i] -> fail = p -> next[i];
                            break;
                        }
                        p = p -> fail;
                    }
                    if(p == NULL)
					{
						temp -> next[i] -> fail = root;
					}
                }
                q[tail++] = temp -> next[i];
            }
        }
    }
}

int solve_AC(char *ch)
{
	int ans;
    Node *p = root;
    int len = strlen(ch);
    for(int i = 0;i < len;i++)
    {
        int x = ch[i] - 'a';
        while(!p -> next[x] && p != root)
		{
			p = p -> fail;
		}
        p = p -> next[x];
        if(!p)
		{
			p = root;
		}
        Node *temp = p;
        while(temp != root)
        {
			if(temp -> sum >= 0)
			{
				ans += temp -> sum;
				temp -> sum = -1;
			}
           	else
		   	{
		   		break;	
		   	}
            temp = temp -> fail;
        }
    }
    return ans;
}

int main()
{
    root = (struct Node *)malloc(sizeof(struct Node));
    for(int i = 0;i < 26;i++)
	{
		root -> next[i] = 0;
	}
    root -> fail = 0;
    root -> sum = 0;
    scanf("%d",&n);
    for(int i = 1;i <= n;i++)
    {
        cin >> key;
        build_trie(key);
    }
    cin >> pattern;
    build_fail();
    printf("%d\n",solve_AC(pattern));
    return 0;
}

注意:对于 Luogu 提交,两个代码均无法过,会报错(next 数组),但 Dev-C++ 应该不会保错,修改 next 数组即可。

而第二道题目只需改一下solve函数即可(其他函数微调):

#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;

struct Node
{
	int fail;
	int next[26];
	int num;
}trie[100005];

int tot;
int ans[155];
string a[155];

void clean(int x)
{
	memset(trie[x].next,0,sizeof(trie[x].next));
	trie[x].fail = 0;
	trie[x].num = 0;
}

void build_trie(int num,string str)
{
	int id = 0;
	int len = str.size();
	for(int i = 0;i < len;i++)
	{
		int now = str[i] - 'a';
		if(trie[id].next[now] == 0)
		{
			trie[id].next[now] = ++tot;
			clean(tot);
		}
		id = trie[id].next[now];
	}
	trie[id].num = num;
}

void build_fail()
{
	queue<int> q;
	int id = 0;
	for(int i = 0;i < 26;i++)
	{
		int now = trie[id].next[i];
		if(now != 0)
		{
			q.push(now);
			trie[now].fail = id;
		}
	}
	while(!q.empty())
	{
		int f = q.front();
		q.pop();
		for(int i = 0;i < 26;i++)
	    {
	        int now = trie[f].next[i];
	        if(now == 0)
	        {
	        	trie[f].next[i] = trie[trie[f].fail].next[i];
	            continue;
	        }
	        trie[now].fail = trie[trie[f].fail].next[i];
	        q.push(now);
	    }
	}
}

void solve_AC(string art)
{
	int id = 0;
	int len = art.size();
	for(int i = 0;i < len;i++)
	{
		id = trie[id].next[art[i] - 'a'];
		for(int j = id;j;j = trie[j].fail)
		{
			ans[trie[j].num]++;
		}
	}
}

int main()
{
	while(1)
	{
		int n;
		scanf("%d",&n);
		if(n == 0)
		{
			break;
		}
		tot = 0;
		clean(0);
		memset(ans,0,sizeof(ans));
		string str;
		for(int i = 1;i <= n;i++)
		{
			cin >> str;
			a[i] = str;
			build_trie(i,str);
		}
		build_fail();
		cin >> str;
		solve_AC(str);
		int maxn = -1;
		for(int i = 1;i <= n;i++)
		{
			maxn = max(maxn,ans[i]);
		}
		printf("%d\n",maxn);
		for(int i = 1;i <= n;i++)
		{
			if(ans[i] == maxn)
			{
				cout << a[i] << endl;
			}
		}
	}
	return 0;
}

这篇提交无问题,因为未使用 bits/stdc++.h 万能头文件。

五、习题练习

L语言

单词

阿狸的打字机

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值