深度学习神奇数据结构 Trie : 两种实现 Trie 的方式 / 真实运用面 / 与其他知识点的结合

1,236 阅读6分钟

本文正在参加「金石计划 . 瓜分6万现金大奖」

字典树

字典树又称为 Trie

是一种神奇的数据结构,不仅能够方便快速判断某个字符串是否存在,还能在 O(C)O(C) 复杂度内判断某个前缀是否存在。

今天我们将通过 44 道与 Trie 相关的题目来学习 Trie,掘友们加油 ~

208. 实现 Trie (前缀树) - 模板题

基本题意

Trie(发音类似 "try")或者说 前缀树 是一种树形数据结构,用于高效地存储和检索字符串数据集中的键。这一数据结构有相当多的应用情景,例如自动补完和拼写检查。

请你实现 Trie 类:

  • Trie() 初始化前缀树对象。
  • void insert(String word) 向前缀树中插入字符串 word 。
  • boolean search(String word) 如果字符串 word 在前缀树中,返回 true(即,在检索之前已经插入);否则,返回 false 。
  • boolean startsWith(String prefix) 如果之前已经插入的字符串 word 的前缀之一为 prefix ,返回 true ;否则,返回 false 。

示例:

输入
["Trie", "insert", "search", "search", "startsWith", "insert", "search"]
[[], ["apple"], ["apple"], ["app"], ["app"], ["app"], ["app"]]

输出
[null, null, true, false, true, null, true]

解释
Trie trie = new Trie();
trie.insert("apple");
trie.search("apple");   // 返回 True
trie.search("app");     // 返回 False
trie.startsWith("app"); // 返回 True
trie.insert("app");
trie.search("app");     // 返回 True

提示:

  • 1 <= word.length, prefix.length <= 2000
  • word 和 prefix 仅由小写英文字母组成
  • insert、search 和 startsWith 调用次数 总计 不超过 3 * 10410^4
Trie 树

TrieTrie 树(又叫「前缀树」或「字典树」)是一种用于快速查询「某个字符串/字符前缀」是否存在的数据结构。

其核心是使用「边」来代表有无字符,使用「点」来记录是否为「单词结尾」以及「其后续字符串的字符是什么」。

二维数组

一个朴素的想法是直接使用「二维数组」来实现 TrieTrie 树。

  • 使用二维数组 trie[]trie[] 来存储我们所有的单词字符。
  • 使用 indexindex 来自增记录我们到底用了多少个格子(相当于给被用到格子进行编号)。
  • 使用 count[]count[] 数组记录某个格子被「被标记为结尾的次数」(当 idxidx 编号的格子被标记了 nn 次,则有 cnt[idx]=ncnt[idx] = n)。

代码 :

class Trie {
    int N = 100009; // 直接设置为十万级
    int[][] trie;
    int[] count;
    int index;

    public Trie() {
        trie = new int[N][26];
        count = new int[N];
        index = 0;
    }
    
    public void insert(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) trie[p][u] = ++index;
            p = trie[p][u];
        }
        count[p]++;
    }
    
    public boolean search(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) return false;
            p = trie[p][u];
        }
        return count[p] != 0;
    }
    
    public boolean startsWith(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) return false;
            p = trie[p][u];
        }
        return true;
    }
}
  • 时间复杂度:TrieTrie 树的每次调用时间复杂度取决于入参字符串的长度。复杂度为 O(Len)O(Len)
  • 空间复杂度:二维数组的高度为 nn,字符集大小为 kk。复杂度为 O(nk)O(nk)
TrieNode

相比二维数组,更加常规的做法是建立 TrieNodeTrieNode 结构节点。

随着数据的不断插入,根据需要不断创建 TrieNodeTrieNode 节点。

代码:

class Trie {
    class TrieNode {
        boolean end;
        TrieNode[] tns = new TrieNode[26];
    }

    TrieNode root;
    public Trie() {
        root = new TrieNode();
    }

    public void insert(String s) {
        TrieNode p = root;
        for(int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (p.tns[u] == null) p.tns[u] = new TrieNode();
            p = p.tns[u]; 
        }
        p.end = true;
    }

    public boolean search(String s) {
        TrieNode p = root;
        for(int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (p.tns[u] == null) return false;
            p = p.tns[u]; 
        }
        return p.end;
    }

    public boolean startsWith(String s) {
        TrieNode p = root;
        for(int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (p.tns[u] == null) return false;
            p = p.tns[u]; 
        }
        return true;
    }
}
  • 时间复杂度:TrieTrie 树的每次调用时间复杂度取决于入参字符串的长度。复杂度为 O(Len)O(Len)
  • 空间复杂度:结点数量为 nn,字符集大小为 kk。复杂度为 O(nk)O(nk)
两种方式的对比

使用「二维数组」的好处是写起来飞快,同时没有频繁 newnew 对象的开销。但是需要根据数据结构范围估算我们的「二维数组」应该开多少行。

坏处是使用的空间通常是「TrieNodeTrieNode」方式的数倍,而且由于通常对行的估算会很大,导致使用的二维数组开得很大,如果这时候每次创建 TrieTrie 对象时都去创建数组的话,会比较慢,而且当样例多的时候甚至会触发 GCGC(因为 OJOJ 每测试一个样例会创建一个 TrieTrie 对象)。

因此还有一个小技巧是将使用到的数组转为静态,然后利用 indexindex 自增的特性在初始化 TrieTrie 时执行清理工作 & 重置逻辑。

这样的做法能够使评测时间降低一半,运气好的话可以得到一个与「TrieNodeTrieNode」方式差不多的时间。

class Trie {
    // 以下 static 成员独一份,被创建的多个 Trie 共用
    static int N = 100009; // 直接设置为十万级
    static int[][] trie = new int[N][26];
    static int[] count = new int[N];
    static int index = 0;

    // 在构造方法中完成重置 static 成员数组的操作
    // 这样做的目的是为减少 new 操作(无论有多少测试数据,上述 static 成员只会被 new 一次)
    public Trie() {
        for (int row = index; row >= 0; row--) {
            Arrays.fill(trie[row], 0);
        }
        Arrays.fill(count, 0);
        index = 0;
    }
    
    public void insert(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) trie[p][u] = ++index;
            p = trie[p][u];
        }
        count[p]++;
    }
    
    public boolean search(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) return false;
            p = trie[p][u];
        }
        return count[p] != 0;
    }
    
    public boolean startsWith(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (trie[p][u] == 0) return false;
            p = trie[p][u];
        }
        return true;
    }
}

关于「二维数组」是如何工作 & 1e5 大小的估算

要搞懂为什么行数估算是 1e5,首先要搞清楚「二维数组」是如何工作的。

在「二维数组」中,我们是通过 indexindex 自增来控制使用了多少行的。

当我们有一个新的字符需要记录,我们会将 indexindex 自增(代表用到了新的一行),然后将这新行的下标记录到当前某个前缀的格子中。

举个🌰,假设我们先插入字符串 abc 这时候,前面三行会被占掉。

  • 第 0 行 a 所对应的下标有值,值为 1,代表前缀 a 后面接的字符串会被记录在下标为 1 的行内

  • 第 1 行 b 所对应的下标有值,值为 2,代表前缀 ab 后面接的字符串会被记录在下标为 2 的行内

  • 第 2 行 c 所对应的下标有值,值为 3,代表前缀 abc 后面接的字符串会被记录在下标为 3 的行内

当再插入 abcl 的时候,这时候会先定位到 abl 的前缀行(第 3 行),将 l 的下标更新为 4,代表 abcl 被加入前缀树,并且前缀 abcl 接下来会用到第 4 行进行记录。

但当插入 abl 的时候,则会定位到 ab 的前缀行(第 2 行),然后将 l 的下标更新为 5,代表 abl 被加入前缀树,并且前缀 abl 接下来会使用第 5 行进行记录。

当搞清楚了「二维数组」是如何工作之后,我们就能开始估算会用到多少行了,调用次数为 10410^4,传入的字符串长度为 10310^3,假设每一次的调用都是 insertinsert,并且每一次调用都会使用到新的 10310^3 行。那么我们的行数需要开到 10710^7

但由于我们的字符集大小只有 26,因此不太可能在 10410^4 次调用中都用到新的 10310^3 行。

而且正常的测试数据应该是 searchsearchstartsWithstartsWith 调用次数大于 insertinsert 才有意义的,一个只有 insertinsert 调用的测试数据,任何实现方案都能 AC。

因此我设定了 10510^5 为行数估算,当然直接开到 10610^6 也没有问题。

关于 Trie 的应用面

首先,在纯算法领域,前缀树算是一种较为常用的数据结构。

不过如果在工程中,不考虑前缀匹配的话,基本上使用 hash 就能满足。

如果考虑前缀匹配的话,工程也不会使用 Trie 。

一方面是字符集大小不好确定(题目只考虑 26 个字母,字符集大小限制在较小的 26 内)因此可以使用 Trie,但是工程一般兼容各种字符集,一旦字符集大小很大的话,Trie 将会带来很大的空间浪费。

另外,对于个别的超长字符 Trie 会进一步变深。

这时候如果 Trie 是存储在硬盘中,Trie 结构过深带来的影响是多次随机 IO,随机 IO 是成本很高的操作。

同时 Trie 的特殊结构,也会为分布式存储将会带来困难。

因此在工程领域中 Trie 的应用面不广。

至于一些诸如「联想输入」、「模糊匹配」、「全文检索」的典型场景在工程主要是通过 ES (ElasticSearch) 解决的。

而 ES 的实现则主要是依靠「倒排索引」


648. 单词替换 - 模板级运用

基本题意

在英语中,我们有一个叫做 词根(root) 的概念,可以词根后面添加其他一些词组成另一个较长的单词——我们称这个词为 继承词(successor)。例如,词根an,跟随着单词 other(其他),可以形成新的单词 another(另一个)。

现在,给定一个由许多词根组成的词典 dictionary 和一个用空格分隔单词形成的句子 sentence。你需要将句子中的所有继承词用词根替换掉。如果继承词有许多可以形成它的词根,则用最短的词根替换它。

你需要输出替换之后的句子。

示例 1:

输入:dictionary = ["cat","bat","rat"], sentence = "the cattle was rattled by the battery"

输出:"the cat was rat by the bat"

示例 2:

输入:dictionary = ["a","b","c"], sentence = "aadsfasf absbs bbab cadsfafs"

输出:"a a b c"

提示:

  • 1<=dictionary.length <=10001 <= dictionary.length <= 1000
  • 1<=dictionary[i].length<=1001 <= dictionary[i].length <= 100
  • dictionary[i] 仅由小写字母组成。
  • 1<=sentence.length<=1061 <= sentence.length <= 10^6
  • sentence 仅由小写字母和空格组成。
  • sentence 中单词的总量在范围 [1,1000][1, 1000] 内。
  • sentence 中每个单词的长度在范围 [1,1000][1, 1000] 内。
  • sentence 中单词之间由一个空格隔开。
  • sentence 没有前导或尾随空格。
基本分析

这是一道 Trie 的模板题,还不了解 Trie 的同学可以先看前置 🧀:【设计数据结构】实现 Trie (前缀树)

前置 🧀 通过图解形式讲解了 Trie 的结构与原理,以及提供了两种实现 Trie 的方式。

回到本题,为了方便,我们令 dsdictionary,令 ssentence

二维数组

一个比较习惯的做法,是使用「二维数组」来实现 Trie,配合 static 优化,可以有效控制 new 的次数,耗时相对稳定。

考虑两个 Trie 的基本操作:

  • add 操作:变量入参字符串 s,将字符串中的每位字符映射到 [0,25][0, 25],同时为了能够方便查询某个字符串(而不只是某个前缀)是否曾经存入过 Trie 中,额外使用一个布尔数组 isEnd 记录某个位置是否为单词结尾。
  • query 操作:

至于二维数组的大小估算,可以直接开成 N×CN \times C,其中 NN 为要插入到 Trie 中的字符总数,CC 为对应的字符集大小。在 N×CN \times C 没有 MLE 风险时,可以直接开这么多;而当 N×CN \times C 较大(超过 1e71e7,甚至 1e81e8 时),可以适当将 N×CN \times C 中的 NN 减少,使得总空间在 1e71e7 左右,因为实际上由于二维数组中的某些行中会存储一个字符以上,实际上我们用不到这么多行。

代码(不使用 static 优化,耗时增加十倍):

class Solution {
    static int N = 100000, M = 26;
    static int[][] tr = new int[N][M];
    static boolean[] isEnd = new boolean[N * M];
    static int idx;
    void add(String s) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (tr[p][u] == 0) tr[p][u] = ++idx;
            p = tr[p][u];
        }
        isEnd[p] = true;
    }
    String query(String s) {
        for (int i = 0, p = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (tr[p][u] == 0) break;
            if (isEnd[tr[p][u]]) return s.substring(0, i + 1);
            p = tr[p][u];
        }
        return s;
    }
    public String replaceWords(List<String> ds, String s) {
        for (int i = 0; i <= idx; i++) {
            Arrays.fill(tr[i], 0);
            isEnd[i] = false;
        }
        for (String d : ds) add(d);
        StringBuilder sb = new StringBuilder();
        for (String str : s.split(" ")) sb.append(query(str)).append(" ");
        return sb.substring(0, sb.length() - 1);
    }
}
  • 时间复杂度:令 n=i=0ds.length1ds[i].lengthn = \sum_{i = 0}^{ds.length - 1} ds[i].lengthmms 长度,复杂度为 O(n+m)O(n + m)
  • 空间复杂度:O(n×C)O(n \times C),其中 C=26C = 26 为字符集大小
TrieNode

另外一个能够有效避免「估数组大小」操作的方式,是使用 TrieNode 的方式实现 Trie:每次使用到新的格子再触发 new 操作。

至于为什么有了 TrieNode 的方式,我还是会采用「二维数组」优先的做法,在 知乎 上有同学问过我类似的问题,只不过原问题是「为什么有了动态开点线段树,直接 build4n4n 空间的做法仍有意义」,这对应到本题使用「二维数组」还是「TrieNode」是一样的道理:

除非某些语言在启动时,采用虚拟机的方式,并且预先分配了足够的内存,否则所有的 new 操作都需要反映到 os 上,而在 linux 分配时需要遍历红黑树,因此即使是总空间一样,一次性的 new 要比多次小空间的 new 更省时间,同时集中性的 new 也比分散性的 new 操作要更快,这也就是为什么我们不无脑使用 TrieNode 的原因。

代码:

class Solution {
    class Node {
        boolean isEnd;
        Node[] tns = new Node[26];
    }
    Node root = new Node();
    void add(String s) {
        Node p = root;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (p.tns[u] == null) p.tns[u] = new Node();
            p = p.tns[u];
        }
        p.isEnd = true;
    }
    String query(String s) {
        Node p = root;
        for (int i = 0; i < s.length(); i++) {
            int u = s.charAt(i) - 'a';
            if (p.tns[u] == null) break;
            if (p.tns[u].isEnd) return s.substring(0, i + 1);
            p = p.tns[u];
        }
        return s;
    }
    public String replaceWords(List<String> ds, String s) {
        for (String str : ds) add(str);
        StringBuilder sb = new StringBuilder();
        for (String str : s.split(" ")) sb.append(query(str)).append(" ");
        return sb.substring(0, sb.length() - 1);
    }
}
  • 时间复杂度:令 n=i=0ds.length1ds[i].lengthn = \sum_{i = 0}^{ds.length - 1} ds[i].lengthmms 长度,复杂度为 O(n+m)O(n + m)
  • 空间复杂度:O(n×C)O(n \times C),其中 C=26C = 26 为字符集大小

421. 数组中两个数的最大异或值 - 结合贪心的运用题

基本题意

给你一个整数数组 numsnums ,返回 nums[i] XOR nums[j] 的最大运算结果,其中 0ij<n0 ≤ i ≤ j < n

进阶:你可以在 O(n)O(n) 的时间解决这个问题吗?

示例 1:

输入:nums = [3,10,5,25,2,8]

输出:28

解释:最大运算结果是 5 XOR 25 = 28.

示例 2:

输入:nums = [0]

输出:0

示例 3:

输入:nums = [2,4]

输出:6

示例 4:

输入:nums = [8,10,2]

输出:10

示例 5:

输入:nums = [14,70,53,83,49,91,36,80,92,51,66,70]

输出:127

提示:

  • 1<=nums.length<=2×1041 <= nums.length <= 2 \times 10^4
  • 0<=nums[i]<=23110 <= nums[i] <= 2^{31} - 1
基本分析

要求得数组 nums 中的「最大异或结果」,假定 nums[i]nums[i]nums[j]nums[j] 异或可以取得最终结果。

由于异或计算「每位相互独立」(又称为不进位加法),同时具有「相同值异或结果为 00,不同值异或结果为 11」的特性。

因此对于 nums[j]nums[j] 而言,可以从其二进制表示中的最高位开始往低位找,尽量让每一位的异或结果为 11,这样找到的 nums[i]nums[i]nums[j]nums[j] 的异或结果才是最大的。

具体的,我们需要先将 nums 中下标范围为 [0,j][0, j] 的数(二进制表示)加入 TrieTrie 中,然后每次贪心的匹配每一位(优先匹配与之不同的二进制位)。

证明

由于我们会从前往后扫描 nums 数组,因此 nums[j]nums[j] 必然会被处理到,所以我们只需要证明,在选定 nums[j]nums[j] 的情况下,我们的算法能够在 [0,j][0, j] 范围内找到 nums[i]nums[i] 即可。

假定我们算法找出来的数值与 nums[j]nums[j] 的异或结果为 xx,而真实的最优异或结果为 yy

接下来需要证得 xxyy 相等。

由于找的是「最大异或结果」, 而 xx 是一个合法值,因此我们天然有 xyx \leq y

然后利用反证法证明 xyx \geq y,假设 xyx \geq y 不成立,即有 x<yx < y,那么从两者的二进制表示的高位开始找,必然能找到第一位不同:yy 的「不同位」的值为 11,而 xx 的「不同位」的值为 00

那么对应到选择这一个「不同位」的逻辑:能够选择与 nums[j]nums[j] 该位不同的值,使得该位的异或结果为 11,但是我们的算法选择了与 nums[j]nums[j] 该位相同的值,使得该位的异或结果为 00

这与我们的算法逻辑冲突,因此必然不存在这样的「不同位」。即 x<yx < y 不成立,反证 xyx \geq y 成立。

得证 xxyy 相等。

Trie 数组实现

可以使用数组来实现 TrieTrie,但由于 OJ 每跑一个样例都会创建一个新的对象,因此使用数组实现,相当于每跑一个数据都需要 new 一个百万级别的数组,会 TLE 。

因此这里使用数组实现必须要做的一个优化是:使用 static 来修饰 TrieTrie 数组,然后在初始化时做相应的清理工作。

担心有不熟 Java 的同学,在代码里添加了相应注释说明。

代码:

class Solution {
    // static 成员整个类独一份,只有在类首次加载时才会创建,因此只会被 new 一次
    static int N = (int)1e6;
    static int[][] trie = new int[N][2];
    static int idx = 0;
    // 每跑一个数据,会被实例化一次,每次实例化的时候被调用,做清理工作
    public Solution() {
        for (int i = 0; i <= idx; i++) {
            Arrays.fill(trie[i], 0);
        }
        idx = 0;
    }
    void add(int x) {
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int u = (x >> i) & 1;
            if (trie[p][u] == 0) trie[p][u] = ++idx;
            p = trie[p][u];
        }
    }
    int getVal(int x) {
        int ans = 0;
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int a = (x >> i) & 1, b = 1 - a;
            if (trie[p][b] != 0) {
                ans |= (b << i);
                p = trie[p][b];
            } else {
                ans |= (a << i);
                p = trie[p][a];
            }
        }
        return ans;
    }
    public int findMaximumXOR(int[] nums) {
        int ans = 0;
        for (int i : nums) {
            add(i);
            int j = getVal(i);
            ans = Math.max(ans, i ^ j);
        }
        return ans;
    }
}
  • 时间复杂度:O(n)O(n)
  • 空间复杂度:O(1e6)O(1e6)
Trie 类实现

相比于使用 static 来优化,一个更好的做法是使用类来实现 TrieTrie,这样可以真正做到「按需分配」内存,缺点是会发生不确定次数的 new

代码:

class Solution {
    class Node {
        Node[] ns = new Node[2];
    }
    Node root = new Node();
    void add(int x) {
        Node p = root;
        for (int i = 31; i >= 0; i--) {
            int u = (x >> i) & 1;
            if (p.ns[u] == null) p.ns[u] = new Node();
            p = p.ns[u];
        }
    }
    int getVal(int x) {
        int ans = 0;
        Node p = root;
        for (int i = 31; i >= 0; i--) {
            int a = (x >> i) & 1, b = 1 - a;
            if (p.ns[b] != null) {
                ans |= (b << i);
                p = p.ns[b];
            } else {
                ans |= (a << i);
                p = p.ns[a];
            }
        }
        return ans;
    }
    public int findMaximumXOR(int[] nums) {
        int ans = 0;
        for (int i : nums) {
            add(i);
            int j = getVal(i);
            ans = Math.max(ans, i ^ j);
        }
        return ans;
    }
}
  • 时间复杂度:O(n)O(n)
  • 空间复杂度:O(n)O(n)

1707. 与数组中元素的最大异或值 - 结合「二分」

给你一个由非负整数组成的数组 numsnums 。另有一个查询数组 queriesqueries,其中 queries[i]=[xi,mi]queries[i] = [xi, mi]

ii 个查询的答案是 xix_i 和任何 numsnums 数组中不超过 mim_i 的元素按位异或(XOR)得到的最大值。

换句话说,答案是 max(nums[j] \{XOR} x_i) ,其中所有 jj 均满足 nums[j]<=minums[j] <= m_i 。如果 numsnums 中的所有元素都大于 mim_i,最终答案就是 1-1

返回一个整数数组 answeranswer 作为查询的答案,其中 answer.length==queries.lengthanswer.length == queries.lengthanswer[i]answer[i] 是第 ii 个查询的答案。

示例 1:

输入:nums = [0,1,2,3,4], queries = [[3,1],[1,3],[5,6]]

输出:[3,3,7]

解释:
1) 0 和 1 是仅有的两个不超过 1 的整数。0 XOR 3 = 31 XOR 3 = 2 。二者中的更大值是 3 。
2) 1 XOR 2 = 3.
3) 5 XOR 2 = 7.

示例 2:

输入:nums = [5,2,4,6,6,3], queries = [[12,4],[8,1],[6,3]]

输出:[15,-1,5]

提示:

  • 1 <= nums.length, queries.length <= 10510^5
  • queries[i].length == 2
  • 0 <= nums[j], xi, mi <= 10910^9
基本分析

在做本题之前,请先确保已经完成 421. 数组中两个数的最大异或值

这种提前给定了所有询问的题目,我们可以运用离线思想(调整询问的回答顺序)进行求解。

对于本题有两种离线方式可以进行求解。

普通 Trie

第一种方法基本思路是:不一次性地放入所有数,而是每次将需要参与筛选的数字放入 TrieTrie,再进行与 [421. 数组中两个数的最大异或值] 类似的贪心查找逻辑。

具体的,我们可以按照下面的逻辑进行处理:

  1. nums 进行「从小到大」进行排序,对 queries 的第二维进行「从小到大」排序(排序前先将询问原本的下标映射关系存下来)。
  2. 按照排序顺序处理所有的 queries[i]
    1. 在回答每个询问前,将小于等于 queries[i][1] 的数值存入 TrieTrie。由于我们已经事先对 nums 进行排序,因此这个过程只需要维护一个在 nums 上有往右移动的指针即可。
    2. 然后利用贪心思路,查询每个 queries[i][0] 所能找到的最大值是多少,计算异或和(此过程与 [421. 数组中两个数的最大异或值] 一致)。
    3. 找到当前询问在原询问序列的下标,将答案存入。

代码:

class Solution {
    static int N = (int)1e5 * 32;
    static int[][] trie = new int[N][2];
    static int idx = 0;
    public Solution() {
        for (int i = 0; i <= idx; i++) {
            Arrays.fill(trie[i], 0);
        }
        idx = 0;
    }
    void add(int x) {
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int u = (x >> i) & 1;
            if (trie[p][u] == 0) trie[p][u] = ++idx;
            p = trie[p][u];
        }
    }
    int getVal(int x) {
        int ans = 0;
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int a = (x >> i) & 1, b = 1 - a;
            if (trie[p][b] != 0) {
                p = trie[p][b];
                ans = ans | (b << i);
            } else {
                p = trie[p][a];
                ans = ans | (a << i);
            } 
        }
        return ans ^ x;
    }
    public int[] maximizeXor(int[] nums, int[][] qs) {
        int m = nums.length, n = qs.length;

        // 使用哈希表将原本的顺序保存下来
        Map<int[], Integer> map = new HashMap<>();
        for (int i = 0; i < n; i++) map.put(qs[i], i);

        // 将 nums 与 queries[x][1] 进行「从小到大」进行排序
        Arrays.sort(nums);
        Arrays.sort(qs, (a, b)->a[1]-b[1]);

        int[] ans = new int[n];
        int loc = 0; // 记录 nums 中哪些位置之前的数已经放入 Trie
        for (int[] q : qs) {
            int x = q[0], limit = q[1];
            // 将小于等于 limit 的数存入 Trie
            while (loc < m && nums[loc] <= limit) add(nums[loc++]);
            if (loc == 0) {
                ans[map.get(q)] = -1;    
            } else {
                ans[map.get(q)] = getVal(x);    
            }
        }
        return ans;
    }
}
  • 时间复杂度:令 nums 的长度为 mqs 的长度为 n。两者排序的复杂度为 O(mlogm)O(m\log{m})O(nlogn)O(n\log{n});将所有数插入 TrieTrie 和从 TrieTrie 中查找的复杂度均为 O(Len)O(Len)LenLen3232。 整体复杂度为 O(mlogm+nlogn+(m+n)Len)O(m\log{m} + n\log{n} + (m + n) * Len) = O(mmax(logm,Len)+nmax(logn,Len))O(m * \max(\log{m}, Len) + n * \max(\log{n}, Len))
  • 空间复杂度:O(C)O(C)。其中 CC 为常数,固定为 1e53221e5 * 32 * 2
计数 Trie & 二分

另外一个比较「增加难度」的做法是,将整个过程翻转过来:一次性存入所有的 TrieTrie 中,然后每次将不再参与的数从 TrieTrie 中移除。相比于解法一,这就要求我们为 TrieTrie 增加一个「删除/计数」功能,并且需要实现二分来找到移除元素的上界下标是多少。

具体的,我们可以按照下面的逻辑进行处理:

  1. nums 进行「从大到小」进行排序,对 queries 的第二维进行「从大到小」排序(排序前先将询问原本的下标映射关系存下来)。
  2. 按照排序顺序处理所有的 queries[i]
    1. 在回答每个询问前,通过「二分」找到在 nums 中第一个满足「小于等于 queries[i][1] 的下标在哪」,然后将该下标之前的数从 TrieTrie 中移除。同理,这个过程我们需要使用一个指针来记录上一次删除的下标位置,避免重复删除。
    2. 然后利用贪心思路,查询每个 queries[i][0] 所能找到的最大值是多少。注意这是要判断当前节点是否有被计数,如果没有则返回 1-1
    3. 找到当前询问在原询问序列的下标,将答案存入。

代码:

class Solution {
    static int N = (int)1e5 * 32;
    static int[][] trie = new int[N][2];
    static int[] cnt = new int[N];
    static int idx = 0;
    public Solution() {
        for (int i = 0; i <= idx; i++) {
            Arrays.fill(trie[i], 0);
            cnt[i] = 0;
        }
        idx = 0;
    }
    // 往 Trie 存入(v = 1)/删除(v = -1) 某个数 x
    void add(int x, int v) {
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int u = (x >> i) & 1;
            if (trie[p][u] == 0) trie[p][u] = ++idx;
            p = trie[p][u];
            cnt[p] += v;
        }
    }
    int getVal(int x) {
        int ans = 0;
        int p = 0;
        for (int i = 31; i >= 0; i--) {
            int a = (x >> i) & 1, b = 1 - a;
            if (cnt[trie[p][b]] != 0) {
                p = trie[p][b];
                ans = ans | (b << i);
            } else if (cnt[trie[p][a]] != 0) {
                p = trie[p][a];
                ans = ans | (a << i);
            } else {
                return -1;
            }
        }
        return ans ^ x;
    }
    public int[] maximizeXor(int[] nums, int[][] qs) {
        int n = qs.length;
        
        // 使用哈希表将原本的顺序保存下来
        Map<int[], Integer> map = new HashMap<>();
        for (int i = 0; i < n; i++) map.put(qs[i], i);

        // 对两者排降序
        sort(nums);
        Arrays.sort(qs, (a, b)->b[1]-a[1]);

        // 将所有数存入 Trie
        for (int i : nums) add(i, 1);

        int[] ans = new int[n];
        int left = -1; // 在 nums 中下标「小于等于」left 的值都已经从 Trie 中移除
        for (int[] q : qs) {
            int x = q[0], limit = q[1];
            // 二分查找到待删除元素的右边界,将其右边界之前的所有值从 Trie 中移除。
            int right = getRight(nums, limit);            
            for (int i = left + 1; i < right; i++) add(nums[i], -1);
            left = right - 1;
            ans[map.get(q)] = getVal(x);
        }
        return ans;
    }
    // 二分找到待删除的右边界
    int getRight(int[] nums, int limit) {
        int l = 0, r = nums.length - 1;
        while (l < r) {
            int mid = l + r >> 1;
            if (nums[mid] <= limit) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return nums[r] <= limit ? r : r + 1;
    }
    // 对 nums 进行降序排序(Java 没有 Api 直接支持对基本类型 int 排倒序,其他语言可忽略)
    void sort(int[] nums) {
        Arrays.sort(nums);
        int l = 0, r = nums.length - 1;
        while (l < r) {
            int c = nums[r];
            nums[r--] = nums[l];
            nums[l++] = c;
        }
    }
}
  • 时间复杂度:令 nums 的长度为 mqs 的长度为 n,常数 Len=32Len = 32。两者排序的复杂度为 O(mlogm)O(m\log{m})O(nlogn)O(n\log{n});将所有数插入 TrieTrie 的复杂度为 O(mLen)O(m * Len);每个查询都需要经过「二分」找边界,复杂度为 O(nlogm)O(n\log{m});最坏情况下所有数都会从 TrieTrie 中被标记删除,复杂度为 O(mLen)O(m * Len)。 整体复杂度为 O(mlogm+nlogn+nlogm+mLen)O(m\log{m} + n\log{n} + n\log{m} + mLen) = O(mmax(logm,Len)+nmax(logm,logn))O(m * \max(\log{m}, Len) + n * \max(\log{m}, \log{n}))
  • 空间复杂度:O(C)O(C)。其中 CC 为常数,固定为 1e53231e5 * 32 * 3

总结

经过对 44 道与 Trie 相关题目的深度讲解,相信大家对 Trie 都有所认识。

其中题目一通过图解的方式向大家展示 Trie 的相关结构,并分享实现 Trie 的两种方式(模板),还向大家分享了 Trie 真实的运用场景;

随后通过一道题目,加强对模板代码的运用;

最后再通过 Trie 与「贪心」以及「二分」的结合来加深对 Trie 的理解。

希望大家能够每道题目都做三遍以上。