数据结构扩展(二) —— 线段树

253 阅读3分钟

用途

给定一个长度为n的序列,需要:

  • 频繁的求某个区间的最值
  • 频繁的更新某个区间的部分/所有值

「线段树」可以解决这类需要维护区间信息的问题,可以在O(logN)的时间复杂度哪实现:

  • 单点修改O(logN)
  • 区间修改(*需要用到lazy propogation来优化到logN,完全不在面试范围内)
  • 区间查询O(logN)区间求和/区间最大值/区间最小值/区间最小公倍数/区间最大公因数

image.png

  • 叶子节点存储输入的数组元素
  • 每一个内部节点表示某些叶子节点的合并 (merge)
    • 合并的方法可能会因问题而异。如上图所示,合并指的是某个节点之下的所有叶子节点的和(求区间最大最小的时候就不是求和了,而是找最大或者最小
  • 对于下标i的节点,其左孩子为 2*i+1,右孩子为2*i+2,父节点为floor((i - 1)/2)
  • 默认method为:
    • update():单点更新
    • rangeSum()
    • rangeMax()
    • rangeMin()
    • *rangeUpdate()


线段树的两种实现

❤️ 普通线段树 - tree based

image.png

class SegementTreeNode:
    def __init__(self, start: int, end: int):
        self.left = self.right = None  # node的左右子节点
        self.start, self.end = start, end  # node的范围(闭区间)
        self.sum = 0  # sum(nums[l,r])


class SegementTree:
    def __init__(self, nums):
        self.nums = nums
        self.root = self.build_tree(0, len(nums) - 1)
    
    def build_tree(self, start, end):
        if start > end:
            return None
        node = SegementTreeNode(start, end)
        if start == end:
            node.sum = self.nums[start]
        else:
            mid = (start + end) // 2
            node.left = self.build_tree(start, mid)
            node.right = self.build_tree(mid + 1, end)
            node.sum = node.left.sum + node.right.sum
        return node
    
    def update(self, node, index, val):
        if node.start == node.end:
            node.sum = val
        else:
            mid = (node.start + node.end) // 2
            if index <= mid:
                self.update(node.left, index, val)
            else:
                self.update(node.right, index, val)
            node.sum = node.left.sum + node.right.sum
    
    def range_sum(self, node, start, end):
        if start > end:
            return 0
        if node.start == start and node.end == end:
            return node.sum
        mid = (node.start + node.end) // 2
        if end <= mid:
            return self.range_sum(node.left, start, end)
        elif start > mid:
            return self.range_sum(node.right, start, end)
        else:
            return self.range_sum(node.left, start, mid) + self.range_sum(node.right, mid + 1, end)

*ZKW线段树

PS:适合contest,不适合面试

image.png

class ZKWSegementTree:
    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.st = [0] * (2 * self.n)
        for i in range(self.n, self.n * 2):  # leaf node是原始数组的值
            self.st[i] = nums[i - self.n]
        for i in range(self.n - 1, 0, -1):  # parent = 两个child的sum
            self.st[i] = self.st[i * 2] + self.st[i * 2 + 1]
    
    def update(self, i: int, val: int) -> None:
        diff = val - self.st[i + self.n]
        i += self.n
        while i > 0:
            self.st[i] += diff
            i //= 2
    
    def rangeSum(self, l, r) -> int:
        res = 0
        l += self.n
        r += self.n
        while l <= r:
            if l % 2 == 1:  # st[l]是left child
                res += self.st[l]
                l += 1
            if r % 2 == 0:  # st[r]是right child
                res += self.st[r]
                r -= 1
            l //= 2
            r //= 2
        return res


题目

307. 区域和检索 - 数组可修改(Median)

image.png

Solution:

  • 线段树(tree-based)模版,略

Code:

class SegementTreeNode:
    def __init__(self, start: int, end: int):
        self.left = self.right = None  # node的左右子节点
        self.start, self.end = start, end  # node的范围(闭区间)
        self.sum = 0  # sum(nums[l,r])


class SegementTree:
    def __init__(self, nums):
        self.nums = nums
        self.root = self.build_tree(0, len(nums) - 1)
    
    def build_tree(self, start, end):
        if start > end:
            return None
        node = SegementTreeNode(start, end)
        if start == end:
            node.sum = self.nums[start]
        else:
            mid = (start + end) // 2
            node.left = self.build_tree(start, mid)
            node.right = self.build_tree(mid + 1, end)
            node.sum = node.left.sum + node.right.sum
        return node
    
    def update(self, node, index, val):
        if node.start == node.end:
            node.sum = val
        else:
            mid = (node.start + node.end) // 2
            if index <= mid:
                self.update(node.left, index, val)
            else:
                self.update(node.right, index, val)
            node.sum = node.left.sum + node.right.sum
    
    def range_sum(self, node, start, end):
        if start > end:
            return 0
        if node.start == start and node.end == end:
            return node.sum
        mid = (node.start + node.end) // 2
        if end <= mid:
            return self.range_sum(node.left, start, end)
        elif start > mid:
            return self.range_sum(node.right, start, end)
        else:
            return self.range_sum(node.left, start, mid) + self.range_sum(node.right, mid + 1, end)


class NumArray:
    
    def __init__(self, nums: List[int]):
        self.st = SegementTree(nums)
    
    def update(self, index: int, val: int) -> None:
        self.st.update(self.st.root, index, val)
    
    def sumRange(self, left: int, right: int) -> int:
        return self.st.range_sum(self.st.root, left, right)

315. 计算右侧小于当前元素的个数(Hard)

image.png

Solution:

  • 只求count的话,可以转化为先bucket sort再做rangeSum(消除元素数量变化的影响)

Code:

class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        nums = [num + 10001 for num in nums]
        n = len(nums)
        exist = [0] * 20002
        st = SegementTree([0] * 20002)
        exist[nums[-1]] += 1
        st.update(st.root, nums[-1], 1)
        res = [0] * n
        for i in range(n - 2, -1, -1):
            exist[nums[i]] += 1
            st.update(st.root, nums[i], exist[nums[i]])  # 统计新的数字
            res[i] = st.range_sum(st.root, 0, nums[i] - 1)  # 找0~nums[i]-1范围内目前有几个数
        return res


Summary

  • prefixSum / BinaryIndexTree / SegmentTree 都可以用来处理rangeQuery系列问题
    • 一般来讲,凡是可以使用树状数组解决的问题,使用线段树也可以解决;但是线段树能够解决的问题树状数组未必能够解决 (例如求区间最大/小值)
  • 什么情况下,无法使用线段树?
    • 如果我们删除或者增加区间中的元素,那么区间的大小将发生变化,此时是无法使用线段树解决这种问题的。(如果只求count可以转化为bucket sort再做rangeSum就没有元素的增加减少了)