ahocorasick算法

270 阅读9分钟

零、参考资料

baijiahao.baidu.com/s?id=163301…

blog.csdn.net/qq_43800119…

zhuanlan.zhihu.com/p/161700510

blog.csdn.net/qq_52965253…

附:我的原文在语雀文档:www.yuque.com/zheng-gmkoc… 《ahocorasick算法》

一、应用场景

用于一段文本匹配多个关键词,采用最原始的办法,可以写出以下的代码,时间复杂度为O(N*M),其中N为text文本长度,M为关键字个数

def originMatch(text, keywords):
    for i in range(len(keywords)):
        keyword = keywords[i]
        if keyword in text:
            print(f"关键词{keyword}匹配成功")

text = "abccab"
keywords = ["ab", "bc", "ca"]
originMatch(text, keywords)

而ahocorasick算法将这类多关键词匹配文本的算法时间复杂度降到了O(N),即:只与文本长度有关

二、介绍

Aho-Corasick算法是一种用于多模式字符串匹配的算法,它可以高效地在一段文本中同时查找多个关键词(模式)。该算法由 Alfred V. AhoMargaret J. Corasick 于 1975 年提出,是一种经典的字符串匹配算法之一。

Aho-Corasick算法的主要思想是构建一个自动机(trie树)来存储所有的模式,同时通过添加额外的指针和状态转移来实现高效的匹配。下面是Aho-Corasick算法的主要步骤:

  1. 构建Trie树: 将所有的模式串插入到一个Trie树中。Trie树是一种树状数据结构,每个节点表示一个字符,路径从根节点到叶子节点表示一个模式串。
  2. 添加失败指针: 为Trie树中的每个节点添加一个失败指针,该指针指向在当前节点失败时应该跳转的节点。这个过程使用广度优先搜索(BFS) 来实现。
  3. 匹配过程: 从文本的开头开始,沿着Trie树逐字符匹配。如果无法匹配当前字符,根据失败指针跳转到另一个节点,并继续匹配,直到匹配成功或者到达文本末尾。

这样,Aho-Corasick算法可以高效地在一次遍历中找到所有的模式串,并且具有线性的时间复杂度

三、算法流程/原理

3.1 Trie树的简介

Trie(也称为字典树、前缀树或字母树) 是一种树状的数据结构,用于存储关联数组,其中键通常是字符串。Trie树具有以下主要特点:

  1. 树状结构: Trie 是一种树状结构,每个节点代表一个字符或者一个字符的部分。
  2. 路径表示键: 从树的根到任意节点的路径上的字符连接起来表示一个键。
  3. 前缀匹配: Trie树对于具有相同前缀的键具有共享的节点。这使得Trie树非常适合进行字符串搜索和前缀匹配。
  4. 高效插入和查找: Trie树的结构允许在O(L)的时间复杂度内插入和查找,其中L是键的长度。
  5. 无需比较: 与二叉查找树不同,Trie不需要在节点之间进行比较,因为每个节点都代表一个确定的字符。

Trie树在处理字符串相关问题时非常有用,例如搜索引擎中的自动完成、拼写检查,以及字典的实现等。对于大量字符串的存储和检索,Trie树可以提供高效的性能。

3.2 Trie树构建

假设有以下关键词patterns = ["he", "she", "his", "her", "hello", "help", "sheep", "shy"],构建出以下trie树。

实现的python代码如下

class TrieNode:
    def __init__(self):
        # 当前节点的字符
        self.char = ''
        # 子结点
        self.children = {}
        # 是否为单词结尾
        self.is_end_of_word = False
        self.failure_link = None

def build_trie(patterns):
    # 创建根节点
    root = TrieNode()
    # 开始模式匹配创建匹配树
    for pattern in patterns:
        # 每一此新的字符串开始的时候,都从根节点开始
        node = root
        # 遍历字符
        for char in pattern:
            # 如果当前的字符并没有存在当前节点的子结点中,就需要重建一个节点作为子结点
            if char not in node.children:
                node.children[char] = TrieNode()
                node.children[char].char = char
            # 往下一层继续探索
            node = node.children[char]
        # 最底层的node要有单词末尾结束标记
        node.is_end_of_word = True
    return root

patterns = ["he", "she", "his", "her", "hello", "help", "sheep", "shy"]

root = build_trie(patterns)

3.3 添加失败指针

逻辑

  • 根节点: 失败指向节点为null。

  • 第一层节点: 失败指向节点为根节点

  • 其余非根节点: 对于状态匹配失败的节点a

    • 如果其父节点的失败节点可以根据该节点a的跳转字符成功转移状态到另一节点b,那么就将失败节点a指向该节点b;

    • 如果其父节点的失败节点不能根据该节点的跳转字符转移状态到另一节点,那么就将失败节点a将检查其父节点的父节点的失败节点是否满足上述条件;

    • 依次递推,如果回溯到根节点还未找到,那就指向将失败节点指向根节点。

代码如下

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end_of_word = False
        self.failure_link = None

def build_trie(patterns):
    root = TrieNode()
    for pattern in patterns:
        node = root
        for char in pattern:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
    return root

def add_failure_links(root):
    queue = []
    # 每个根节点下的节点一旦失败,就会返回到根节点重新查询
    for child in root.children.values():
        queue.append(child)
        child.failure_link = root

    # 层序遍历
    while queue:
        current_node = queue.pop(0)
        # 遍历当前节点
        for char, child in current_node.children.items():
            # 将当前节点的子结点存入队列
            queue.append(child)
            # 拿出失败时的节点
            failure_node = current_node.failure_link
            # 对于状态匹配失败的节点a,如果其父节点的失败节点可以根据该节点a的跳转字符成功转移状态到另一节点b,
            # 那么就将失败节点a指向该节点b;如果其父节点的失败节点不能根据该节点的跳转字符转移状态到另一节点,
            # 那么就将失败节点a将检查其父节点的父节点的失败节点是否满足上述条件;
            # 依次递推,如果回溯到根节点还未找到,那就指向将失败节点指向根节点。
            while failure_node and char not in failure_node.children:
                failure_node = failure_node.failure_link
            # 三元运算符相当于 a ? b : c
            child.failure_link = failure_node.children[char] if failure_node.children[char] else root

以patterns=["ab", "bc", "ca"] 为例,过程如下。

构建trie树

构建第一层节点的失败指针

构建第一层节点a的子结点b的失败指针

构建第一层节点b的子结点c的失败指针

构建第一层节点c的子结点a的失败指针

3.4 匹配逻辑

一直往子结点匹配,如果当前节点的子结点没有符合改字符的,就去找失败指针指向的节点,依此类推。

def search(text, patterns):
    root = build_trie(patterns)
    add_failure_links(root)

    current_node = root
    matches = []

    for i, char in enumerate(text):
        # 如果当前的节点存在,且当前字符匹配失败,就去找失败节点
        while current_node and char not in current_node.children:
            current_node = current_node.failure_link

        # 如果当前的节点为空,其实就意味着找到了根节点的失败节点了,
        # 此时表明查询失败重新从根节点开始查询
        if not current_node:
            current_node = root
            continue
        
        # 如果char存在于当前的节点的子结点中,说明成功匹配
        # 递进去下一个节点
        current_node = current_node.children[char]
        # 如果此时匹配到的值已经是匹配关键词的末尾了,说明匹配中了一个字串
        # 记录下来
        if current_node.is_end_of_word:
            matches.append((current_node.word, (i - len(current_node.word) + 1, i)))
            if current_node.failure_link.is_end_of_word:
                tempNode = current_node.failure_link
                matches.append((tempNode.word, (i - len(tempNode.word) + 1, i)))
    return matches

四、完整代码

class TrieNode:
    def __init__(self):
        # 当前节点的子结点
        self.children = {}
        # 当前节点是否为结尾
        self.is_end_of_word = False
        # 当前节点查询失败后转到下一个节点
        self.failure_link = None
        # 当前节点的字符
        self.char = ""
        # 当前节点的字符叠加,用于记录叶子节点的对应字符串
        self.sumChart = ""
        # 叶子节点的单词
        self.word = ""

def build_trie(patterns):
    root = TrieNode()
    for pattern in patterns:
        currentNode = root
        for char in pattern:
            # 如果这个字符并没有出现在当前节点的子结点中,那么就要创造一个这样的节点作为子结点
            if char not in currentNode.children:
                childNode = TrieNode()
                # 赋值字符
                childNode.char = char
                # 子结点的累加字符为当前节点的累加字符并添加当前的字符
                childNode.sumChart = currentNode.sumChart + char
            # 当前子结点已经存在,直接赋值即可
            else:
                childNode = currentNode.children[char]
            currentNode.children[char] = childNode
            # 然后继续遍历子结点
            currentNode = currentNode.children[char]
        currentNode.word = currentNode.sumChart
        currentNode.is_end_of_word = True
    return root

def add_failure_links(root):
    queue = []
    # 每个根节点下的节点一旦失败,就会返回到根节点重新查询
    for child in root.children.values():
        queue.append(child)
        child.failure_link = root

    # 层序遍历
    while queue:
        current_node = queue.pop(0)
        # 遍历当前节点
        for char, childNode in current_node.children.items():
            # 将当前节点的子结点存入队列
            queue.append(childNode)
            # 拿出失败时的节点
            failure_node = current_node.failure_link
            '''
            对于状态匹配失败的节点a,
            如果其父节点的失败节点可以根据该节点a的跳转字符成功转移状态到另一节点b,那么就将失败节点a指向该节点b;
            如果其父节点的失败节点不能根据该节点的跳转字符转移状态到另一节点,那么就将失败节点a将检查其父节点的父节点的失败节点是否满足上述条件;
            依次递推,如果回溯到根节点还未找到,那就指向将失败节点指向根节点。
            '''
            while failure_node and char not in failure_node.children:
                failure_node = failure_node.failure_link
            
            # 如果当前的failure_node存在并且其以char为值的子结点也存在
            # 当前的childNode节点的失败指针就指向failure_node节点的失败指针指向的存有该char值的子结点
            if failure_node and failure_node.children[char]:
                childNode.failure_link = failure_node.children[char]
                # 这里需要注意,如果当前的childNode节点的失败时指向的节点是叶节点,其实也意味着
                # 查找到这个子结点的时候,已经匹配到了一个字串了
                if childNode.failure_link.is_end_of_word:
                    # 当childNode为非根节点的时候需要赋值word为其失败后指向节点的word
                    if not childNode.is_end_of_word:
                        childNode.is_end_of_word = True
                        # 既然当前的节点表示已经匹配到了一个字串,那么相应的该节点应该也是有一个单词的
                        childNode.word = childNode.failure_link.word
            else:
                childNode.failure_link = root

def search(text, patterns):
    root = build_trie(patterns)
    add_failure_links(root)

    current_node = root
    matches = []

    for i, char in enumerate(text):
        # 如果当前的节点存在,且当前字符匹配失败,就去找失败节点
        while current_node and char not in current_node.children:
            current_node = current_node.failure_link

        # 如果当前的节点为空,其实就意味着找到了根节点的失败节点了,
        # 此时表明查询失败重新从根节点开始查询
        if not current_node:
            current_node = root
            continue
        
        # 如果char存在于当前的节点的子结点中,说明成功匹配
        # 递进去下一个节点
        current_node = current_node.children[char]
        # 如果此时匹配到的值已经是匹配关键词的末尾了,说明匹配中了一个字串
        # 记录下来
        if current_node.is_end_of_word:
            matches.append((current_node.word, (i - len(current_node.word) + 1, i)))
            if current_node.failure_link.is_end_of_word:
                tempNode = current_node.failure_link
                matches.append((tempNode.word, (i - len(tempNode.word) + 1, i)))
    # 数据处理
    data = list(set(matches))
    result = []
    for item in data:
        tempObj = {
            "patternString": item[0],
            "matchStartIndex": item[1][0],
            "matchEndIndex": item[1][1]
        }
        result.append(tempObj)
    # 按matchEndIndex排序
    sortedResult = sorted(result, key=lambda x: (x['matchEndIndex'], x['matchStartIndex']))
    return sortedResult

# Example
text = "abccab"
patterns = ["ab", "bc", "ca", "ccab"]
result = search(text, patterns)
for item in result:
    print(item)

五、python中的pyahocorasick使用

import ahocorasick

def search_with_pyahocorasick(text, patterns):
    A = ahocorasick.Automaton()
    for pattern in patterns:
        A.add_word(pattern, pattern)

    A.make_automaton()

    matches = []
    for item in A.iter(text):
        end_index = item[0]
        pattern = item[1]
        start_index = end_index - len(pattern) + 1
        matches.append((pattern, (start_index, end_index)))
    data = list(set(matches))
    result = []
    for item in data:
        tempObj = {
            "patternString": item[0],
            "matchStartIndex": item[1][0],
            "matchEndIndex": item[1][1]
        }
        result.append(tempObj)
    sortedResult = sorted(result, key=lambda x: (x['matchEndIndex'], x['matchStartIndex']))
    return sortedResult

# Example
text = "abccab"
patterns = ["ab", "bc", "ca", "ccab"]
result = search_with_pyahocorasick(text, patterns)
for item in result:
    print(item)