AC（Aho-Corasick）文本匹配算法

85 阅读1分钟

# 原理

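前缀树（Trie）与 fail 函数（失败指针）的构造：先把所有关键词逐字插入前缀树；再按广度优先顺序为每个节点设置 fail 指针，使其指向“该节点所代表字符串的最长真后缀”在树中对应的节点（找不到则指向根）。匹配文本时，若当前节点没有与下一个字符对应的子节点，就沿 fail 指针回退后继续匹配，而不必回到文本起点重新扫描；同时，到达某节点后还要沿它的 fail 链检查结束标记，才能找出以当前位置结尾的所有关键词。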

树结构

package com.example.demo.ac;

import lombok.Data;

import java.util.*;

/**
 * Node of an Aho-Corasick keyword trie.
 *
 * <p>Each node stores one token. In practice a token is one {@code char}, because
 * {@link #matchWords} splits the text with {@code split("")}. The path from the root to a node
 * spells a keyword prefix. {@code failNode} points at the node for the longest proper suffix of
 * that prefix that is also present in the trie. This lets matching continue without rescanning
 * the text.
 *
 * <p>Accessors are written by hand instead of using Lombok {@code @Data}. The generated
 * {@code equals}/{@code hashCode}/{@code toString} would follow {@code parentNode},
 * {@code childen} and {@code failNode}, which reference each other, and recurse until a
 * {@link StackOverflowError}. Nodes therefore use identity equality.
 */
public class TreeNode {
    /** Token stored on this node. */
    private String text;
    /** True when the path from the root to this node spells a complete keyword. */
    private boolean endFlag = false;
    /** Child nodes. The field name is kept because callers use getChilden()/setChilden(). */
    private List<TreeNode> childen = new ArrayList<>();
    /** Node to continue from when no child matches the next token. */
    private TreeNode failNode;
    /** Parent node; null for a node that was never inserted under another one (e.g. the root). */
    private TreeNode parentNode;
    /** True only for the trie root. */
    private boolean rootFlag = false;

    public String getText() {
        return text;
    }

    public void setText(String text) {
        this.text = text;
    }

    public boolean isEndFlag() {
        return endFlag;
    }

    public void setEndFlag(boolean endFlag) {
        this.endFlag = endFlag;
    }

    public List<TreeNode> getChilden() {
        return childen;
    }

    public void setChilden(List<TreeNode> childen) {
        this.childen = childen;
    }

    public TreeNode getFailNode() {
        return failNode;
    }

    public void setFailNode(TreeNode failNode) {
        this.failNode = failNode;
    }

    public TreeNode getParentNode() {
        return parentNode;
    }

    public void setParentNode(TreeNode parentNode) {
        this.parentNode = parentNode;
    }

    public boolean isRootFlag() {
        return rootFlag;
    }

    public void setRootFlag(boolean rootFlag) {
        this.rootFlag = rootFlag;
    }

    /**
     * Returns the direct child whose token equals {@code text}.
     *
     * @param text token to look for
     * @return the matching child, or {@code null} if there is none
     */
    public TreeNode getChildrenByText(String text) {
        if (this.childen == null) {
            return null;
        }
        for (TreeNode child : this.childen) {
            if (Objects.equals(child.getText(), text)) {
                return child;
            }
        }
        return null;
    }

    /**
     * Adds one keyword, given as its sequence of tokens, below {@code root}.
     *
     * <p>The node reached by the last token is marked as a keyword end even when it already
     * existed. This covers inserting "ab" after "abc".
     *
     * @param words tokens of the keyword; a null or empty array is ignored
     * @param root  trie root to insert under
     */
    public static void insert(String[] words, TreeNode root) {
        if (words == null || words.length == 0 || root == null) {
            return;
        }
        TreeNode node = root;
        for (String text : words) {
            TreeNode child = node.getChildrenByText(text);
            if (child == null) {
                child = new TreeNode();
                child.setText(text);
                child.setParentNode(node);
                node.getChilden().add(child);
            }
            node = child;
        }
        node.setEndFlag(true);
    }

    /**
     * Computes and stores the fail node of {@code current}.
     *
     * <p>Walks the fail chain starting at the parent's fail node and stops at the first node
     * that has a child with the same token. That child becomes the fail node, giving the longest
     * matching suffix. If the walk reaches the root without a match, the root is used.
     * Precondition: the parent's fail node is already set (call in breadth-first order) and the
     * root has {@code rootFlag} set.
     *
     * @param current node to link; ignored when it has no parent
     */
    public static void addFailNode(TreeNode current) {
        TreeNode parent = current.getParentNode();
        if (parent == null) {
            return;
        }
        if (parent.isRootFlag()) {
            // A first-level node's only proper suffix is the empty string, i.e. the root.
            current.setFailNode(parent);
            return;
        }
        TreeNode candidate = parent.getFailNode();
        while (candidate != null) {
            TreeNode next = candidate.getChildrenByText(current.getText());
            if (next != null) {
                current.setFailNode(next);
                return;
            }
            if (candidate.isRootFlag()) {
                current.setFailNode(candidate);
                return;
            }
            candidate = candidate.getFailNode();
        }
    }

    /**
     * Computes the fail node of every node below {@code root}.
     *
     * <p>Nodes are visited breadth-first, so a parent's fail node is always known before its
     * children are linked.
     *
     * @param root trie root; must have {@code rootFlag} set. A null root is ignored.
     */
    public static void buildFailNodes(TreeNode root) {
        if (root == null) {
            return;
        }
        Deque<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (node.getChilden() != null) {
                queue.addAll(node.getChilden());
            }
            if (!node.isRootFlag()) {
                addFailNode(node);
            }
        }
    }

    /**
     * Finds every occurrence of every keyword in {@code words}.
     *
     * <p>The trie's fail nodes must already be built.
     *
     * @param words text to scan; null or empty yields no matches
     * @param root  trie root
     * @return one {@code {start, end}} pair per hit, both inclusive {@code char} indices into
     *     {@code words}, ordered by end position and then from longest to shortest keyword
     */
    public static List<Integer[]> matchWords(String words, TreeNode root) {
        List<Integer[]> result = new ArrayList<>();
        if (words == null || words.isEmpty() || root == null) {
            return result;
        }
        TreeNode current = root;
        int endPoint = 0;
        for (String value : words.split("")) {
            TreeNode matchNode = failedAfterDeal(value, current);
            if (matchNode == null) {
                // Nothing matches even from the root: restart at the next character.
                current = root;
            } else {
                current = matchNode;
                // The node itself and every node on its fail chain end at this position.
                // Each one that terminates a keyword is a hit (e.g. "bc" inside "abc").
                for (TreeNode hit = matchNode; hit != null && !hit.isRootFlag(); hit = hit.getFailNode()) {
                    if (hit.isEndFlag()) {
                        // Derive the start from the keyword length. A start counter would go
                        // stale after following a fail link.
                        result.add(new Integer[] {endPoint - keywordLength(hit) + 1, endPoint});
                    }
                }
            }
            endPoint++;
        }
        return result;
    }

    /** Total token length on the path from the root down to {@code node}. */
    private static int keywordLength(TreeNode node) {
        int length = 0;
        for (TreeNode n = node; n != null && !n.isRootFlag(); n = n.getParentNode()) {
            length += n.getText().length();
        }
        return length;
    }

    /**
     * Finds the next state for {@code value}, following fail links from {@code current} as
     * needed.
     *
     * @param value   next token of the text
     * @param current current trie state
     * @return the child reached, or {@code null} when not even the root has a child for
     *     {@code value}
     */
    public static TreeNode failedAfterDeal(String value, TreeNode current) {
        TreeNode node = current;
        while (node != null) {
            TreeNode child = node.getChildrenByText(value);
            if (child != null) {
                return child;
            }
            if (node.isRootFlag()) {
                // Stop at the root even if someone set a fail node on it, to avoid looping.
                return null;
            }
            node = node.getFailNode();
        }
        return null;
    }

    /** Shallow description that deliberately avoids following parent/child/fail references. */
    @Override
    public String toString() {
        return "TreeNode{text=" + text + ", endFlag=" + endFlag + ", rootFlag=" + rootFlag
                + ", children=" + (childen == null ? 0 : childen.size()) + "}";
    }
}

测试类

package com.example.demo.ac;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

public class AcTest {
    /**
     * Demo entry point for the keyword trie.
     *
     * <p>Builds the trie from a fixed keyword list and links fail nodes breadth-first. It then
     * prints every match found in a sample sentence: first the inclusive "start:end" indices,
     * then the matched substring.
     */
    public static void main(String[] args) {
        List<String> keywords = new ArrayList<>();
        keywords.add("中国");
        keywords.add("中华共和国");
        keywords.add("禁用动是");
        keywords.add("禁用词");
        keywords.add("禁用动作");

        TreeNode root = new TreeNode();
        root.setText("root");
        root.setRootFlag(true);
        for (String keyword : keywords) {
            TreeNode.insert(keyword.split(""), root);
        }

        // First-level nodes can only fall back to the root itself.
        for (TreeNode child : root.getChilden()) {
            child.setFailNode(root);
        }

        // Breadth-first walk so every parent's fail node exists before its children are linked.
        Queue<TreeNode> pending = new LinkedList<>();
        pending.offer(root);
        TreeNode node;
        while ((node = pending.poll()) != null) {
            for (TreeNode child : node.getChilden()) {
                pending.offer(child);
            }
            boolean alreadyLinked = node.isRootFlag() || node.getFailNode() != null;
            if (!alreadyLinked) {
                TreeNode.addFailNode(node);
            }
        }

        String data = "正的加的四号中华共和国不发空间发禁用动是零距离几率记禁用动作录";
        for (Integer[] span : TreeNode.matchWords(data, root)) {
            System.out.println(span[0] + ":" + span[1]);
            System.out.println(data.substring(span[0], span[1] + 1));
        }
    }
}