AC自动机算法总结

199 阅读1分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第24天,点击查看活动详情


前言:

类似kmp,ac自动机也是个字符串匹配算法,不过kmp只能单模式串与文本串的匹配,ac自动机可以实现多模式串与文本串的匹配。ac自动机需要两个前置知识点:kmp和trie。但实际上用到的只有trie罢了,感觉和kmp关系并不大,至少我记的模板是这样的。

分三步:

  1. 构建trie树

    这部分就是普通的字典树插入操作,一模一样的。

    void Insert()
    {
            int now = 0, len = strlen(word);
            for(int i = 0; i < len; i++)
            {
                    int t = word[i];
                    if(!son[now][t])
                            son[now][t] = ++idx;
                    now = son[now][t];
            }
            cnt[now]++;
    }
    
  2. 构建fail指针及改造trie图

    这一部分很重要,首先明确fail指针含义,fail[i]表示以节点编号i为终止的字符串上可匹配的最长后缀,可匹配指与trie树上以节点编号j为终止的字符串完全一样,此时fail[i]=j。那如何确定节点i的fail指针呢,考虑其父节点的fail指针fafail,如果节点fafail具有一个和i一样的儿子,那i的fail就是fafail的那个儿子,如果fafail没有这样的一个儿子,那就去fafail的fail那里去找(其实是一个递归的过程),如果找到了根节点也没有这样的一个儿子,那就没办法了,fail[i]置为0即可,表示trie上不存在一个前缀能和该字符串某个后缀匹配。由于fail指针与父节点有关,所以选用bfs层次遍历比较合适,遍历到某节点时,其父节点的fail一定是更新正确的。

    与其它模板不大一样的地方在第16行和第20行,第16行语句用于构建fail指针,第20行用于改造为trie图,类似于并查集中的路径压缩。

    void GetFail()//bfs的过程 
    {
            queue<int> q;
            for(int i = 0; i < 128; i++)
                    if(son[0][i])
                            q.push(son[0][i]);
            while(q.size())
            {
                    int now = q.front();
                    q.pop();
                    for(int i = 0; i < 128; i++)
                    {
                            int to = son[now][i], fafail = fail[now];
                            if(to)
                            {
                                    fail[to] = son[fafail][i];//由于是个递归,这样写没问题 
                                    q.push(to);
                            }
                            else//递归的思想!
                                    son[now][i] = son[fafail][i];//路径压缩,直接到该去的地方 
                    }
            }
    }
    
  3. 文本串上查询

    这部分比较简单,改造为trie图后now = son[now][text[i]]会导致当前节点now在trie图上来回跳,而now到根节点之间的字符串就是文本串上的一段匹配串,第7行的for循环遍历该区间内所有可能存在单词的后缀。

    void Query()
    {
            int now = 0, len = text.size();
            for(int i = 0; i < len; i++)
            {
                    now = son[now][text[i]];//在trie图上乱跳 
                    for(int j = now; j; j = fail[j])//很关键,直接统计以text[i]结尾的所有可能模式串 
                            if(cnt[j])//如果标记有单词就统计一下 
                                    mp[text.substr(i-cnt[j]+1, cnt[j])]++;
            }
    }
    

模板:

以hdu3065为例,给出ac自动机模板。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
#include <string>
#include <unordered_map>
using namespace std;
//巨坑,多组输入题目并没有告诉
//因为考虑到ascii可见部分还可能有空格,于是用getline读入,结果TLE,改用gets就ac了 
int son[50005][128], cnt[50005], idx;//cnt[i]记录i节点对应字符串长度 
int fail[50005];
char word[55], t[2000005];
string text, query[1005];
unordered_map<string, int> mp;
 
void insert()
{
	int now = 0, len = strlen(word);
	for(int i = 0; i < len; i++)
	{
		int t = word[i];
		if(!son[now][t])
			son[now][t] = ++idx;
		now = son[now][t];
	}
	cnt[now] = len;
}
 
void GetFail()
{
	queue<int> q;
	for(int i = 0; i < 128; i++)
		if(son[0][i])
			q.push(son[0][i]);
	while(q.size())
	{
		int now = q.front();
		q.pop();
		for(int i = 0; i < 128; i++)
		{
			int to = son[now][i], fafail = fail[now];
			if(to)
			{
				fail[to] = son[fafail][i];
				q.push(to);
			}
			else
				son[now][i] = son[fafail][i];
		}
	}
}
 
void Query()
{
	int now = 0, len = text.size();
	for(int i = 0; i < len; i++)
	{
		now = son[now][text[i]];//在trie图上乱跳 
		for(int j = now; j; j = fail[j])//很关键,直接统计以text[i]结尾的所有可能模式串 
			if(cnt[j])//如果标记有单词就统计一下 
				mp[text.substr(i-cnt[j]+1, cnt[j])]++;
	}
}
 
signed main()
{
	int n; 
	while(cin >> n)
	{
		for(int i = 0; i <= idx; i++)
		{
			fail[i] = cnt[i] = 0;
			for(int j = 0; j < 128; j++)
				son[i][j] = 0;
		}
//		memset(son, 0, sizeof son);
//		memset(fail, 0, sizeof fail);
//		memset(cnt, 0, sizeof cnt);
		idx = 0;
		mp.clear();
		getchar();
		for(int i = 1; i <= n; i++)
		{
			gets(word);
			query[i] = word;
			insert();
		} 
		GetFail();
		gets(t);
		text = t;
		Query();
		for(int i = 1; i <= n; i++)
		{
			if(mp[query[i]])
				printf("%s: %d\n", query[i].c_str(), mp[query[i]]);
		} 
	}
    return 0;
}