线段树模板

176 阅读6分钟

线段树是一种用于维护 区间信息 的数据结构,本文将结合一道经典例题分析它的引入背景、详细概念和具体实现,作为笔者个人的模板归档。关于线段树细节的分析本文不做介绍,请读者自行搜索网上资料。

1. 引入

先来看一道线段树模板题:307. 区域和检索 - 数组可修改 - 力扣(LeetCode)

我们简单总结下题意,给定一个数组 nums,要求实现一种数据结构 NumArray,完成两类查询:

  • update:更新数组 nums 某个下标对应的值
  • sumRange:查询 nums 某段区间 [left,right][left, right] 所有元素和

最后分析题目难点。给定的 nums 数组长度不超过 31043*10^4,两类查询 updatesumRange 的调用次数均不大于 31043*10^4,所以单次 updatesumRange 的时间复杂度必须小于 O(n)O(n)

显然,本题可以用线段树来解决,其他常见解法还有树状数组、分块。为什么笔者要专门介绍线段树?因为线段树的设计思想更经典、更通用,往往是最推荐的解法,尽管它的代码看起来比较长。

2. 概念

线段树是一种二叉树,用于高效地存储和操作区间信息。线段树的核心是通过二分的思想,将数组下标空间反复二分成多个区间,形成二叉树的结构,在树的节点上存储信息。

例如,对于第 1 节中的模板题,假设给定数组 nums[1, 3, 5, 7, 9, 11],则可以构造一个线段树如下:

              [0, 5]
             /       \
        [0, 2]       [3, 5]
       /     \       /     \
    [0, 1]  (2, 2) [3, 4]  (5, 5)
    /    \         /    \
(0, 0)  (1, 1)  (3, 3)  (4, 4)

小括号为叶节点,存储数组元素信息,即 nums[left]nums[right] 的值。中括号为非叶节点,存储区间信息,即 [left,right][left, right] 的区间和。

利用这种二叉树结构,线段树可以在 O(logn)O(\log n) 的时间复杂度内完成第 1 节中模板题要求的两种查询。事实上,线段树可以在 O(logn)O(\log n) 的时间复杂度内完成 单点修改区间修改区间查询 等操作。

综上所述,线段树常常用于解决这样一类问题。给定一个数组,需要频繁进行以下操作:

  • 区间修改:对区间 [left,right][left, right] 内的每个数进行一次相同操作,例如都加上 kk。此外,单点修改 可以作为一种特殊的区间修改处理。
  • 区间查询:查询区间 [left,right][left, right] 的信息。常见的查询包括区间求和,求区间最大值,求区间最小值。

3. 实现

3.1. 数组估点

因为线段树为一棵「满二叉树」,所以可以使用数组实现。定义数组 d 存储区间信息(即区间和),定义数组 b 存储懒惰标记。数组 d 的长度可以粗略地用 N = 4 * nums.length 来计算。

d[p] 表示 nums 数组中 [left,right][left, right] 的区间和,其区间中点为 mid = left + (right - 1 >> 1),则 d[p] 的左右两个子区间 [left,mid][left, mid][mid+1,right][mid + 1, right] 的对应节点分别为 d[p << 1]d[(p << 1) | 1]

class SegmentTree {
    private int[] d;    // 数组 d 记录当前节点表示的区间信息
    private int[] b;    // 数组 b 为懒惰标记

    public void update(int p, int left, int right, int start, int end, int val) {
        if (start <= left && right <= end) {
            d[p] += (right - left + 1) * val;
            b[p] += val;
            return;
        }
        int mid = left + (right - 1 >> 1);
        pushDown(p, mid - left + 1, right - mid);
        if (start <= mid) {
            update(p << 1, left, mid, start, end, val); // (p << 1) 表示左子树的下标
        }
        if (end > mid) {
            update((p << 1) | 1, mid + 1, right, start, end, val);  // ((p << 1) | 1) 表示右子树的下标
        }
        pushUp(p);
    }

    public int query(int p, int left, int right, int start, int end) {
        if (start <= left && right <= end) {
            return d[p];
        }
        int mid = left + (right - 1 >> 1);
        int res = 0;
        pushDown(p, mid - left + 1, right - mid);
        if (start <= mid) {
            res += query(p << 1, left, mid, start, end);
        }
        if (end > mid) {
            res += query((p << 1) | 1, mid + 1, right, start, end);
        }
        return res;
    }

    private void pushDown(int p, int leftLen, int rightLen) {
        if (d[p] == 0) {
            return; // 当前节点的懒惰标记为零,无需更新
        }
        int leftIndex = p << 1;
        int rightIndex = (p << 1) | 1;
        d[leftIndex] += b[p] * leftLen;
        d[rightIndex] += b[p] * rightLen;
        b[leftIndex] += b[p];
        b[rightIdex] += b[q];
        b[p] = 0;   // 更新完当前节点,清除懒惰标记
    }

    private void pushUp(int p) {
        d[p] = d[p << 1] + d[(p << 1) | 1];
    }
}

3.2. 动态开点

上一节中提到,数组估点实现线段树时,数组 d 的长度可以用数组 nums 长度的四倍来估算。如果 nums 的长度很大,而操作次数却不多,使用数组实现会有大量的空间浪费,甚至 OOM。

为了节省空间,可以采用「动态开点」的方式实现线段树。动态开点的优势在于,不需要事前构造空树,而是在插入操作 update 和查询操作 query 时根据访问需要进行「开点」操作。由于我们不保证插入和查询都是连续的,因此对于父节点 u 的左右两个子区间而言,我们不能通过 u < 1(u << 1) | 1 的固定方式进行访问,而要存储节点 u 的两个子区间对应的节点,分别记为 lcrc 属性。

class SegmentTree {
    private class Node {
        int val, add;
        Node lc, rc;
    }

    /**
     * 对当前节点中与目标区间 [start, end] 相交的部分进行加上 val 的区间修改,也用于建树
     * 
     * @param node  当前根节点
     * @param left  当前节点表示的区间左端点
     * @param right 当前节点表示的区间右端点
     * @param start 目标区间左端点
     * @param end   目标区间右端点
     * @param val   区间修改中要加上的值
     */
    public void update(Node node, int left, int right, int start, int end, int val) {
        if (start <= left && right <= end) {    // 当前节点表示的区间在区间范围内,对全区间进行修改
            node.val += (right - left + 1) * val;
            node.add += val;
            return;
        }
        int mid = left + (right - 1 >> 1);
        pushDown(node, mid - left + 1, right - mid); // 下推懒惰标记
        if (start <= mid) { // 当前节点的左子树与目标区间相交,更新左子树
            update(node.lc, left, mid, start, end, val);
        }
        if (end > mid) {    // 当前节点的右子树与目标区间相交,更新右子树
            update(node.rc, mid + 1, right, start, end, val);
        }
        pushUp(node);   // 上推,将更新后的左右子树的值同步到当前节点
    }

    public int query(Node node, int left, int right, int start, int end) {
        if (start <= left && right <= end) {
            return node.val;
        }
        int mid = left + (right - 1 >> 1);
        int res = 0;
        pushDown(node, mid - left + 1, right - mid);
        if (start <= mid) {
            res += query(node.lc, left, mid, start, end);
        }
        if (end > mid) {
            res += query(node.rc, mid + 1, right, start, end);
        }
        return res;
    }

    private void pushDown(Node node, int leftLen, int rightLen) {
        if (node.add == 0) {
            return; // 当前节点的懒惰标记为零,无需更新
        }
        // 更新左右子树的区间信息
        node.lc.val += node.add * leftLen;
        node.rc.val += node.add * rightLen;
        // 更新左右子树的懒惰标记
        node.lc.add += node.add;
        node.rc.add += node.add;

        node.add = 0;   // 更新完当前节点,清除懒惰标记
    }

    private void pushUp(Node node) {
        node.val = node.lc.val + node.rc.val;
    }
}