算法数据结构:Trie 树

1,545 阅读5分钟

1、什么是 Trie 树

Trie 树,也叫字典树。顾名思义,它就是一个树形结构,是一种专门处理字符串匹配的字典数据结构,用来解决在一组字符串集合中快速找到某个字符串的问题。

当然,这样一个问题可以有多种解决方法,比如散列表、红黑树等。但是 Trie 树在这个问题的解决上,有它的优点。Trie 树能解决的问题不限于此,后面会介绍。

举个例子,我们有 6 个字符串,分别是:how、hi、her、hello、so、see。我们希望在里面多次查找某个字符串是否存在。如果每次查找,都是拿要查找的字符串跟着 6 个字符串依次进行字符串匹配,那效率就比较低,有没有更高效的方法呢?

这个时候,我们就可以先对这 6 个字符串做一下预处理,构建成 Trie 树的结构,之后每次查找,都是在 Trie 树中进行匹配查找。Trie 树的本质,就是利用字符串之间的公共前缀,将重复的前缀合并在一起。过程如下:

其中,根节点不包含任何信息。每个节点表示一个字符串中的字符,从根节点到红色节点的一条路径表示一个字符串(注意:红色并不都是叶子结点)。

当我们在 Trie 树种查找一个字符串 her 的时候,那我们将要查找的字符串分割成单个的字符 h e r,然后从 Trie 树的根节点开始匹配。过程如图所示,绿色的路径就是在 Trie 树中匹配的路径。

如果我们要查找的是字符串 he 呢?我们还用上面同样的方法,如图所示,从根节点开始,沿着某条路径俩匹配,绿色的路径,是字符串 he 匹配的路径。但是,路径的最后一个节点 e 不是红色的。也就是说,he 是某个字符串的前缀子串,但并不能完全匹配任何字符串。

2、如何实现一棵 Trie 树?

2.1、分析

Trie 树主要有两个操作,第一个是将字符串集合构造成 Trie 树,换句话来说,就是一个将字符串插入到 Trie 树的过程。第二个是在 Trie 树中查询一个字符串。

前面可以看出,Trie 树是一个多叉树。对于多叉树来说,如何存储一个结点的所有子结点的指针呢?假设字符串中只有 a 到 z 这 26 个小写字母,如图所示,这里借助散列表的思想,我们通过一个下标与字符一一映射的数组,来存储子结点的指针。我们可以根据哈希函数(字符的 ASCII 码 - 'a' 的 ASCII 码)知道,数组中下标为 0 位置,存储指向子结点 a 的指针,下标为 1 的位置存储指向子结点为 b 的指针,以此类推,下标为 25 的位置,存储的是指向子结点 z 的指针。如果某个字符的子结点不存在,我们就在队形的下标位置存储 null。

2.2、实现

  1. 单个字符串中,字符从前到后的加到一棵多叉树上
  2. 字符放在路上,节点上有专属的数据项(常见的是 pass 和 end 值)
  3. 所有样本都这样添加,如果没有路就新建,如有路就复用
  4. 沿途节点的 pass 值增加 1,每个字符串结束时来到的节点 end 值增加 1
2.2.1、Trie 树测试工具类
import java.util.HashMap;

/**
 * trie 树测试工具
 *
 * @author yangjian
 */
public class TrieUtils {

    public static Right right;

    /**
     * 生成随机字符数组
     *
     * @param arrLen 数组长度
     * @param strLen 字符串长度
     * @return 字符数组
     */
    public static String[] generateRandomStringArray(int arrLen, int strLen) {
        String[] ans = new String[(int) (Math.random() * arrLen) + 1];
        for (int i = 0; i < ans.length; i++) {
            ans[i] = generateRandomString(strLen);
        }
        return ans;
    }

    /**
     * 生成随机字符串
     *
     * @param strLen 字符串长度
     * @return 字符串
     */
    public static String generateRandomString(int strLen) {
        char[] ans = new char[(int) (Math.random() * strLen) + 1];
        for (int i = 0; i < ans.length; i++) {
            int value = (int) (Math.random() * 6);
            ans[i] = (char) (97 + value);
        }
        return String.valueOf(ans);
    }

    public static void newRight() {
        right = new Right();
    }

    public static void insert(String word) {
        right.insert(word);
    }

    public static void delete(String word) {
        right.delete(word);
    }

    public static int search(String word) {
        return right.search(word);
    }

    public static int prefixNumber(String prefix) {
        return right.prefixNumber(prefix);
    }

    public static class Right {

        private final HashMap<String, Integer> root;

        public Right() {
            root = new HashMap<>();
        }

        public void insert(String word) {
            if (!root.containsKey(word)) {
                root.put(word, 1);
            } else {
                root.put(word, root.get(word) + 1);
            }
        }

        public void delete(String word) {
            if (root.containsKey(word)) {
                if (root.get(word) == 1) {
                    root.remove(word);
                } else {
                    root.put(word, root.get(word) - 1);
                }
            }
        }

        public int search(String word) {
            return root.getOrDefault(word, 0);
        }

        public int prefixNumber(String pre) {
            int count = 0;
            for (String cur : root.keySet()) {
                if (cur.startsWith(pre)) {
                    count += root.get(cur);
                }
            }
            return count;
        }
    }
}

2.2.2、Trie 树代码一
import java.util.Objects;

/**
 * @description: Trie 树数组实现,当字符串只包含 a-z 26 个小写字符
 * @author: erlang
 * @since: 2020-09-20 10:56
 */
public class TrieTree {

    public static class Node {
        /**
         * 记录通过次数
         */
        public int pass;

        /**
         * 记录结尾次数
         */
        public int end;

        /**
         * 记录字符
         */
        public Node[] nexts;

        public Node() {
            this.pass = 0;
            this.end = 0;
            // 只有 a-z 26 个小写字母
            // 0 a, 1 b, ..., z 25
            nexts = new Node[26];
        }
    }

    /**
     * 根节点
     */
    private Node root;

    public TrieTree() {
        root = new Node();
    }

    /**
     * 计算字符的索引位置
     *
     * @param ch 字符
     * @return 索引位置
     */
    public int hash(char ch) {
        return ch - 'a';
    }

    /**
     * 添加某个字符串,可以重复添加,每次加一
     *
     * @param word 添加的字符串
     */
    public void insert(String word) {
        if (Objects.isNull(word)) {
            return;
        }
        char[] chars = word.toCharArray();
        if (chars.length == 0) {
            return;
        }
        Node node = root;

        node.pass++;
        for (char ch : chars) {
            // 计算字符的索引位置
            int index = hash(ch);
            if (node.nexts[index] == null) {
                node.nexts[index] = new Node();
            }
            node = node.nexts[index];
            node.pass++;
        }
        node.end++;
    }

    /**
     * 查询某个字符串在结构中的个数
     *
     * @param word 查询的字符串
     * @return 个数
     */
    public int search(String word) {
        if (Objects.isNull(word)) {
            return 0;
        }

        Node node = root;

        for (char ch : word.toCharArray()) {
            // 计算字符的索引位置
            int index = hash(ch);
            if (node.nexts[index] == null) {
                return 0;
            }
            node = node.nexts[index];
        }
        return node.end;
    }

    /**
     * 删除某个字符串,可以重复删除,每次减一
     *
     * @param word 删除的字符串
     */
    public void delete(String word) {
        // 先判断是否存在
        if (search(word) == 0) {
            return;
        }

        Node node = root;
        node.pass--;
        for (char ch : word.toCharArray()) {
            int index = hash(ch);
            if (--node.nexts[index].pass == 0) {
                // 如果当前节点的 pass = 0,则说明已经没有字符串的前缀是当前节点了
                // 后面的所有节点都可以直接删除,无需遍历了
                node.nexts[index] = null;
                return;
            }
            node = node.nexts[index];
        }
        node.end--;
    }

    /**
     * 查询有多少个字符串,是以 prefix 为前缀的
     *
     * @param prefix 查询的前缀字符串
     * @return 个数
     */
    public int prefixNumber(String prefix) {
        if (Objects.isNull(prefix)) {
            return 0;
        }

        Node node = root;
        for (char ch : prefix.toCharArray()) {
            int index = hash(ch);

            if (node.nexts[index] == null) {
                return 0;
            }

            node = node.nexts[index];
        }
        return node.pass;
    }

    public static void main(String[] args) {
        int arrLen = 100;
        int strLen = 20;
        int testTimes = 100000;
        for (int i = 0; i < testTimes; i++) {
            String[] arr = TrieUtils.generateRandomStringArray(arrLen, strLen);
            TrieTree trie = new TrieTree();
            TrieUtils.newRight();
            for (String s : arr) {
                double decide = Math.random();
                if (decide < 0.25) {
                    trie.insert(s);
                    TrieUtils.insert(s);
                } else if (decide < 0.5) {
                    trie.delete(s);
                    TrieUtils.delete(s);
                } else if (decide < 0.75) {
                    int ans1 = trie.search(s);
                    int ans3 = TrieUtils.search(s);
                    if (ans1 != ans3) {
                        System.out.println("Oops13!");
                    }
                } else {
                    int ans1 = trie.prefixNumber(s);
                    int ans3 = TrieUtils.prefixNumber(s);
                    if (ans1 != ans3) {
                        System.out.println("Oops!");
                    }
                }
            }
        }
        System.out.println("finish!");

    }
}
2.2.3、Trie 树代码二
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * @description: Trie 树实现,字符不仅限于 a-z 26 个小写字符
 * @author: erlang
 * @since: 2020-09-20 10:56
 */
public class TrieTree {

    public static class Node {
        /**
         * 记录通过次数
         */
        public int pass;

        /**
         * 记录结尾次数
         */
        public int end;

        /**
         * 记录字符
         */
        public Map<Integer, Node> nexts;

        public Node() {
            pass = 0;
            end = 0;
            nexts = new HashMap<>(16);
        }
    }

    /**
     * 根节点
     */
    public Node root;

    public TrieTree() {
        root = new Node();
    }

    /**
     * 添加某个字符串,可以重复添加,每次加一
     *
     * @param word 添加的字符串
     */
    public void insert(String word) {
        if (Objects.isNull(word)) {
            return;
        }

        char[] chars = word.toCharArray();

        if (chars.length == 0) {
            return;
        }

        Node node = root;

        for (char ch : chars) {
            int index = ch;
            if (!node.nexts.containsKey(index)) {
                node.nexts.put(index, new Node());
            }
            node = node.nexts.get(index);
            node.pass++;
        }
        node.end++;
    }

    /**
     * 查询某个字符串在结构中的个数
     *
     * @param word 查询的字符串
     * @return 个数
     */
    public int search(String word) {
        if (Objects.isNull(word)) {
            return 0;
        }

        Node node = root;
        for (char ch : word.toCharArray()) {
            // 计算字符的索引位置
            int index = ch;
            if (!node.nexts.containsKey(index)) {
                return 0;
            }

            node = node.nexts.get(index);
        }
        return node.end;
    }

    /**
     * 删除某个字符串,可以重复删除,每次减一
     *
     * @param word 删除的字符串
     */
    public void delete(String word) {
        // 先判断是否存在
        if (search(word) == 0) {
            return;
        }

        Node node = root;
        node.pass--;
        for (char ch : word.toCharArray()) {
            int index = ch;
            if (--node.nexts.get(index).pass == 0) {
                // 如果当前节点的 pass = 0,则说明已经没有字符串的前缀是当前节点了
                // 后面的所有节点都可以直接删除,无需遍历了
                node.nexts.remove(index);
                return;
            }
            node = node.nexts.get(index);
        }
        node.end--;
    }

    /**
     * 查询有多少个字符串,是以 prefix 为前缀的
     *
     * @param prefix 查询的前缀字符串
     * @return 个数
     */
    public int prefixNumber(String prefix) {
        if (Objects.isNull(prefix)) {
            return 0;
        }

        Node node = root;
        for (char ch : prefix.toCharArray()) {
            int index = ch;
            if (!node.nexts.containsKey(index)) {
                return 0;
            }
            node = node.nexts.get(index);
        }
        return node.pass;
    }

    public static void main(String[] args) {
        int arrLen = 100;
        int strLen = 20;
        int testTimes = 100000;
        for (int i = 0; i < testTimes; i++) {
            String[] arr = TrieUtils.generateRandomStringArray(arrLen, strLen);
            TrieTree trie = new TrieTree();
            TrieUtils.newRight();
            for (String s : arr) {
                double decide = Math.random();
                if (decide < 0.25) {
                    trie.insert(s);
                    TrieUtils.insert(s);
                } else if (decide < 0.5) {
                    trie.delete(s);
                    TrieUtils.delete(s);
                } else if (decide < 0.75) {
                    int ans1 = trie.search(s);
                    int ans3 = TrieUtils.search(s);
                    if (ans1 != ans3) {
                        System.out.println("Oops13!");
                    }
                } else {
                    int ans1 = trie.prefixNumber(s);
                    int ans3 = TrieUtils.prefixNumber(s);
                    if (ans1 != ans3) {
                        System.out.println("Oops!");
                    }
                }
            }
        }
        System.out.println("finish!");
    }
}

2.3、时间复杂度

如果要在一组字符串汇总,频繁地查询某些字符串,用 Trie 树会非常高效。构建 Trie 树的过程,需要扫描所有的字符串,时间复杂度是 O(n)O(n)(n 表示所有字符串的长度和)。但是一旦构建成功之后,后续的查询操作会非常高效。

每次查询时,如果要查询的字符串长度是 k,那我们只需要比对大约 k 个节点,就能完成查询操作。跟原本那组字符串长度和个数没有任何关系。所以说,构建好 Trie 树后,在其中查找字符串的时间复杂度是 O(k)O(k),k 表示要查找的字符串的长度。

3、Trie 树与散列表、红黑树的异同

实际上,字符串的匹配问题,就是数据的查找问题。对于支持动态数据高效操作的数据结构,比如散列表、红黑树、跳跃表等。实际上,这些数据结构也可以显现在一组字符串中查找字符串的功能。这里用散列表和红黑树跟 Trie 树比较一下,看看他们各自的优缺点和应用场景。

在刚刚讲的这个场景,在一组字符串中查找字符串,Trie 树实际上表现得并不好。它对要处理的字符串有及其严苛的要求。

  1. 字符串中包含的字符集不能太大。如果字符集太大,那存储空间就会浪费很多。即便可以优化,也要付出牺牲查询、插入效率的代价。
  2. 要求字符串的前缀重合比较多,不然空间消耗会变大很多。
  3. 如果用 Trie 树解决问题,我们就要自己从零开始实现一个 Trie 树,还要保证没有 bug,这个在工程上是将简单问题复杂化,除非必须,一般不建议这样做。
  4. 通过指针穿起来的数据块是不连续的,而 Trie 树中用到了指针,所以对缓存并不友好,性能上也会打个折扣。

Trie 树只是不适合精确匹配查找,这种问题更适合用散列表或者红黑树来解决。Trie 树比较适合的是查找前缀匹配的字符串。

针对在一组字符串中查找字符串的问题,我们在工程中,更倾向于散列表或者红黑树。因为这两种数据结构,我们都不需要自己去实现,直接利用编程语言中提供的现成类库就行了。

4、小结

Trie 树是一种用于解决字符串快速匹配问题的数据结构。如果用来构建 Trie 树的这一组字符串中,前缀重复的情况不是很多,那 Trie 树这种数据结构总体上来讲是比较费内存的,是一种空间换时间的解决思路。

尽管比较耗费内存,但是对内存不敏感或者内存消耗在接受范围内的情况下,在 Trie 树中做字符串匹配还是非常高效的,时间复杂度是 O(k)O(k),k 表示要匹配的字符串的长度。

但是,Trie 树的优势并不是用它来做动态集合数据的查找,因为这个工作完全可以用更加合适的散列表或者红黑树来替代。Trie 树最有优势的是查找前缀匹配的字符串,比如搜索引擎总的关键词提示这个功能,就比较适合用它来解决,也是 Trie 树比较经典的应用场景。