数据结构扩展(三) —— 树状树组

388 阅读2分钟

数据结构

  1. 假设对于数组arr[],有一个BIT[]
    • update(): update BIT for operation arr[idx] += val in O(logN)
      • ❗️❗️注意!!不是直接令arr[idx] = val ❌
    • getSum(): return sum(arr[0,idx]) in O(logN)
  2. 先init树状树组BIT[]中每一位为0
  3. 再call update()把arr中的每一位元素都插入

image.png

  • parent -> child: add 1 in binary representation
  • child -> parent: 删去rightmost的1,置为0

sum 和 update 的逻辑

image.png

image.png

image.png

  • extract last set bit: x & (-x)
  • remove it: x - (x & (-x))

Implement getSum(x) -> 拆分成二进制加法的形式:

  • ie:13 = 8+4+1 => sum(13)=range(1,8)+range(9,12)+range(13,13)
def get_sum(bit_arr, idx):
    res = 0
    while idx:
        res += bit_arr[idx]
        idx -= idx & -idx
    return res

Implement update(idx, val)

def update(bit_arr, idx, val):
    while idx < len(bit_arr):
        bit_arr[idx] += val
        idx += idx & -idx

❤️ BIT模版

注意:bit_arr[0]为dummy,从第1位开始❗️❗️

class BIT:
    def __init__(self, n):
        self.bit_arr = [0] * (n + 1)  # 第0位为dummy,从第1位开始
    
    def update(self, i, val):
        i += 1
        while i < len(self.bit_arr):
            self.bit_arr[i] += val
            i += i & -i
    
    def get_sum(self, i):
        res = 0
        i += 1
        while i > 0:
            res += self.bit_arr[i]
            i -= i & -i
        return res

树状树组 VS. 线段树

  • 线段树:可以维护对一个区间查询区间修改
  • 树状数组:是线段树的「阉割版」,经常用来区间查询,但修改只能进行单点修改
    • 经过改造之后可以区间修改,但线段树本身就可以支持区间修改

    • 使用树状数组的原因是因为树状数组比较好写



题目

1649. 通过指令创建有序数组(Hard)

image.png

Solution:

  • 转化为bucket sort

Code:

class Solution:
    def createSortedArray(self, instructions: List[int]) -> int:
        max_num = max(instructions)
        bit = BIT(max_num + 1)
        res = 0
        for num in instructions:
            left, right = bit.get_sum(num - 1), bit.get_sum(max_num) - bit.get_sum(num - 1)
            res = (res + min(left, right)) % (10 ** 9 + 7)
            bit.update(num, 1)
        return res

307. 区域和检索 - 数组可修改(Median)

image.png

Solution:

  • 照抄BIT模版,略

Code:

class NumArray:
    def __init__(self, nums: List[int]):
        self.nums = nums
        self.bit = BIT(len(nums))
        for i, n in enumerate(nums):
            self.bit.update(i, n)
    
    def update(self, index: int, val: int) -> None:
        diff = val - self.nums[index]
        self.nums[index] = val
        self.bit.update(index, diff)
    
    def sumRange(self, left: int, right: int) -> int:
        return self.bit.get_sum(right) - self.bit.get_sum(left - 1)