前缀树

111 阅读4分钟

前缀树

概念

  • 单个字符串中,字符从前到后的加到一棵多叉树上
  • 字符放在路上,节点上有专属的数据项,常见的是pass(通过这个字符的数量)和end(以这个字符结尾的数量)值
  • 所有样本都这样添加,如果没有路就新建,如有路就复用
  • 沿途节点的pass值增加1,每个字符串结束时来到的节点end值增加1

前缀树可以完成前缀相关的查询(前缀模糊查询),哈希表可以完成精确查询

设计

设计一种结构,用户可以进行如下操作

  1. void insert(String str) 添加某个字符串,可以重复添加
  2. int search(String str) 查询某个字符串在结构中还有几个
  3. void delete(String str) 删掉某个字符串,可以重复删除
  4. int prefixNumber(String str) 查询有多少个字符串,是以str做前缀的

代码实现

固定数组实现

public class TrieTreeArr {

    public static class Node {
        public int pass;
        public int end;
        public Node[] children;

        public Node() {
            pass = 0;
            end = 0;
            // 0    a
            // 1    b
            // ..   ..
            // 25   z
            // children[i] == null i方向的路不存在,children[i] != null i方向的路存在
            children = new Node[26]; // 假设字符都是小写的英文字母
        }
    }

    public static class Trie {
        private Node root;

        public Trie() {
            root = new Node();
        }

        public void insert(String word) {
            Node cur = root;
            // 插入数据,通过根节点的数量加1
            cur.pass++;
            for (char c : word.toCharArray()) { // 遍历所有字符
                int index = c - 'a'; // 该字符对应走哪条路,对应的下标
                if (cur.children[index] == null) {
                    cur.children[index] = new Node();
                }
                cur = cur.children[index];
                // 插入字符,通过对应节点的数量加1
                cur.pass++;
            }
            cur.end++;
        }

        public void delete(String word) {
            if (search(word) == 0) {
                return;
            }
            Node cur = root;
            // 删除数据,通过根节点的数量减1
            cur.pass--;
            for (char c : word.toCharArray()) { // 遍历所有字符
                int index = c - 'a'; // 该字符对应走哪条路,对应的下标
                // 删除字符,通过对应节点的数量减1
                if (--cur.children[index].pass == 0) { // 通过对应节点的数量为0,删除该节点,并停止遍历
                    cur.children[index] = null;
                    return;
                }
                // 在insert方法中可以先更新cur,删除数据时不行,因为可能需要删除节点,只能通过父节点删除子节点
                cur = cur.children[index];
            }
            cur.end--;
        }

        // word这个单词之前加入过几次
        public int search(String word) {
            Node node = prefixNode(word);
            return node == null ? 0 : node.end;
        }

        // 所有加入的字符串中,有几个是以pre这个字符串作为前缀的
        public int prefixNumber(String pre) {
            Node node = prefixNode(pre);
            return node == null ? 0 : node.pass;
        }

        // 查找以pre为前缀的节点
        private Node prefixNode(String pre) {
            Node cur = root;
            for (char c : pre.toCharArray()) {
                int index = c - 'a';
                cur = cur.children[index];
                if (cur == null) { // cur为null,说明没有以pre为前缀的节点
                    return null;
                }
            }
            return cur;
        }
    }
}

哈希表实现

public class TrieTreeHash {

    public static class Node {
        public int pass;
        public int end;
        public Map<Integer, Node> children; // char可以用int表示

        public Node() {
            pass = 0;
            end = 0;
            children = new HashMap<>();
        }
    }

    public static class Trie {
        private Node root;

        public Trie() {
            root = new Node();
        }

        public void insert(String word) {
            Node cur = root;
            // 插入数据,通过根节点的数量加1
            cur.pass++;
            for (int key : word.toCharArray()) { // 遍历所有字符,该字符对应走哪条路,即对应的key
                if (!cur.children.containsKey(key)) {
                    cur.children.put(key, new Node());
                }
                cur = cur.children.get(key);
                // 插入字符,通过对应节点的数量加1
                cur.pass++;
            }
            cur.end++;
        }

        public void delete(String word) {
            if (search(word) == 0) {
                return;
            }
            Node cur = root;
            // 删除数据,通过根节点的数量减1
            cur.pass--;
            for (int key : word.toCharArray()) { // 遍历所有字符,该字符对应走哪条路,即对应的key
                // 删除字符,通过对应节点的数量减1
                if (--cur.children.get(key).pass == 0) { // 通过对应节点的数量为0,删除该节点,并停止遍历
                    cur.children.remove(key);
                    return;
                }
                // 在insert方法中可以先更新cur,删除数据时不行,因为可能需要删除节点,只能通过父节点删除子节点
                cur = cur.children.get(key);
            }
            cur.end--;
        }

        // word这个单词之前加入过几次
        public int search(String word) {
            Node node = prefixNode(word);
            return node == null ? 0 : node.end;
        }

        // 所有加入的字符串中,有几个是以pre这个字符串作为前缀的
        public int prefixNumber(String pre) {
            Node node = prefixNode(pre);
            return node == null ? 0 : node.pass;
        }

        private Node prefixNode(String pre) {
            Node cur = root;
            for (int key : pre.toCharArray()) {
                if (!cur.children.containsKey(key)) {
                    return null;
                }
                cur = cur.children.get(key);
            }
            return cur;
        }
    }
}

测试

public class TrieTreeTest {
    // 对数器
    public static class Right {
        private Map<String, Integer> box;

        public Right() {
            box = new HashMap<>();
        }

        public void insert(String word) {
            int value = 1;
            if (box.containsKey(word)) {
                value += box.get(word);
            }
            box.put(word, value);
        }

        public void delete(String word) {
            if (!box.containsKey(word)) {
                return;
            }
            Integer value = box.get(word);
            if (value == 1) {
                box.remove(word);
            } else {
                box.put(word, value - 1);
            }
        }

        public int search(String word) {
            if (box.containsKey(word)) {
                return box.get(word);
            }
            return 0;
        }

        public int prefixNumber(String pre) {
            int res = 0;
            for (Map.Entry<String, Integer> entry : box.entrySet()) {
                String key = entry.getKey();
                if (key.startsWith(pre)) {
                    res += entry.getValue();
                }
            }
            return res;
        }
    }

    // 生成随机字符串,用于测试
    public static String generateRandomString(int strLen) {
        Random random = new Random();
        char[] ans = new char[random.nextInt(strLen) + 1];
        for (int i = 0; i < ans.length; i++) {
            int value = random.nextInt(6);
            ans[i] = (char) (97 + value);
        }
        return String.valueOf(ans);
    }

    // 生成随机字符串数组,用于测试
    public static String[] generateRandomStringArray(int arrLen, int strLen) {
        Random random = new Random();
        String[] ans = new String[random.nextInt(strLen) + 1];
        for (int i = 0; i < ans.length; i++) {
            ans[i] = generateRandomString(strLen);
        }
        return ans;
    }


    public static void main(String[] args) {
        int arrLen = 200;
        int strLen = 20;
        int testTimes = 100000;
        for (int i = 0; i < testTimes; i++) {
            String[] arr = generateRandomStringArray(arrLen, strLen);
            TrieTreeArr.Trie  trie1 = new TrieTreeArr.Trie ();
            TrieTreeHash.Trie  trie2 = new TrieTreeHash.Trie ();
            Right right = new Right();
            for (String s : arr) {
                double decide = Math.random();
                if (decide < 0.25) {
                    trie1.insert(s);
                    trie2.insert(s);
                    right.insert(s);
                } else if (decide < 0.5) {
                    trie1.delete(s);
                    trie2.delete(s);
                    right.delete(s);
                } else if (decide < 0.75) {
                    int ans1 = trie1.search(s);
                    int ans2 = trie2.search(s);
                    int ans3 = right.search(s);
                    if (ans1 != ans2 || ans2 != ans3) {
                        System.out.println("Oops!");
                        trie1.search(s);
                        trie2.search(s);
                    }
                } else {
                    int ans1 = trie1.prefixNumber(s);
                    int ans2 = trie2.prefixNumber(s);
                    int ans3 = right.prefixNumber(s);
                    if (ans1 != ans2 || ans2 != ans3) {
                        System.out.println("Oops!");
                        trie1.prefixNumber(s);
                        trie2.prefixNumber(s);
                    }
                }
            }
        }
        System.out.println("finish!");

    }
}