用途
给定一个长度为n的序列,需要:
- 频繁的求某个区间的最值
- 频繁的更新某个区间的部分/所有值
「线段树」可以解决这类需要维护区间信息的问题,可以在O(logN)的时间复杂度哪实现:
- 单点修改(
O(logN)) - 区间修改(*需要用到lazy propogation来优化到logN,完全不在面试范围内)
- 区间查询(
O(logN)区间求和/区间最大值/区间最小值/区间最小公倍数/区间最大公因数)
- 叶子节点存储输入的数组元素
- 每一个内部节点表示某些叶子节点的合并 (merge)
- 合并的方法可能会因问题而异。如上图所示,合并指的是某个节点之下的所有叶子节点的和(求区间最大最小的时候就不是求和了,而是找最大或者最小)
- 对于下标
i的节点,其左孩子为2*i+1,右孩子为2*i+2,父节点为floor((i - 1)/2)- 默认method为:
update():单点更新rangeSum()rangeMax()rangeMin()- *
rangeUpdate()
线段树的两种实现
❤️ 普通线段树 - tree based
class SegementTreeNode:
def __init__(self, start: int, end: int):
self.left = self.right = None # node的左右子节点
self.start, self.end = start, end # node的范围(闭区间)
self.sum = 0 # sum(nums[l,r])
class SegementTree:
def __init__(self, nums):
self.nums = nums
self.root = self.build_tree(0, len(nums) - 1)
def build_tree(self, start, end):
if start > end:
return None
node = SegementTreeNode(start, end)
if start == end:
node.sum = self.nums[start]
else:
mid = (start + end) // 2
node.left = self.build_tree(start, mid)
node.right = self.build_tree(mid + 1, end)
node.sum = node.left.sum + node.right.sum
return node
def update(self, node, index, val):
if node.start == node.end:
node.sum = val
else:
mid = (node.start + node.end) // 2
if index <= mid:
self.update(node.left, index, val)
else:
self.update(node.right, index, val)
node.sum = node.left.sum + node.right.sum
def range_sum(self, node, start, end):
if start > end:
return 0
if node.start == start and node.end == end:
return node.sum
mid = (node.start + node.end) // 2
if end <= mid:
return self.range_sum(node.left, start, end)
elif start > mid:
return self.range_sum(node.right, start, end)
else:
return self.range_sum(node.left, start, mid) + self.range_sum(node.right, mid + 1, end)
*ZKW线段树
PS:适合contest,不适合面试
class ZKWSegementTree:
def __init__(self, nums: List[int]):
self.n = len(nums)
self.st = [0] * (2 * self.n)
for i in range(self.n, self.n * 2): # leaf node是原始数组的值
self.st[i] = nums[i - self.n]
for i in range(self.n - 1, 0, -1): # parent = 两个child的sum
self.st[i] = self.st[i * 2] + self.st[i * 2 + 1]
def update(self, i: int, val: int) -> None:
diff = val - self.st[i + self.n]
i += self.n
while i > 0:
self.st[i] += diff
i //= 2
def rangeSum(self, l, r) -> int:
res = 0
l += self.n
r += self.n
while l <= r:
if l % 2 == 1: # st[l]是left child
res += self.st[l]
l += 1
if r % 2 == 0: # st[r]是right child
res += self.st[r]
r -= 1
l //= 2
r //= 2
return res
题目
307. 区域和检索 - 数组可修改(Median)
Solution:
- 线段树(tree-based)模版,略
Code:
class SegementTreeNode:
def __init__(self, start: int, end: int):
self.left = self.right = None # node的左右子节点
self.start, self.end = start, end # node的范围(闭区间)
self.sum = 0 # sum(nums[l,r])
class SegementTree:
def __init__(self, nums):
self.nums = nums
self.root = self.build_tree(0, len(nums) - 1)
def build_tree(self, start, end):
if start > end:
return None
node = SegementTreeNode(start, end)
if start == end:
node.sum = self.nums[start]
else:
mid = (start + end) // 2
node.left = self.build_tree(start, mid)
node.right = self.build_tree(mid + 1, end)
node.sum = node.left.sum + node.right.sum
return node
def update(self, node, index, val):
if node.start == node.end:
node.sum = val
else:
mid = (node.start + node.end) // 2
if index <= mid:
self.update(node.left, index, val)
else:
self.update(node.right, index, val)
node.sum = node.left.sum + node.right.sum
def range_sum(self, node, start, end):
if start > end:
return 0
if node.start == start and node.end == end:
return node.sum
mid = (node.start + node.end) // 2
if end <= mid:
return self.range_sum(node.left, start, end)
elif start > mid:
return self.range_sum(node.right, start, end)
else:
return self.range_sum(node.left, start, mid) + self.range_sum(node.right, mid + 1, end)
class NumArray:
def __init__(self, nums: List[int]):
self.st = SegementTree(nums)
def update(self, index: int, val: int) -> None:
self.st.update(self.st.root, index, val)
def sumRange(self, left: int, right: int) -> int:
return self.st.range_sum(self.st.root, left, right)
315. 计算右侧小于当前元素的个数(Hard)
Solution:
- 只求count的话,可以转化为先bucket sort再做rangeSum(消除元素数量变化的影响)
Code:
class Solution:
def countSmaller(self, nums: List[int]) -> List[int]:
nums = [num + 10001 for num in nums]
n = len(nums)
exist = [0] * 20002
st = SegementTree([0] * 20002)
exist[nums[-1]] += 1
st.update(st.root, nums[-1], 1)
res = [0] * n
for i in range(n - 2, -1, -1):
exist[nums[i]] += 1
st.update(st.root, nums[i], exist[nums[i]]) # 统计新的数字
res[i] = st.range_sum(st.root, 0, nums[i] - 1) # 找0~nums[i]-1范围内目前有几个数
return res
Summary
- prefixSum / BinaryIndexTree / SegmentTree 都可以用来处理rangeQuery系列问题
- 一般来讲,凡是可以使用树状数组解决的问题,使用线段树也可以解决;但是线段树能够解决的问题树状数组未必能够解决 (例如求区间最大/小值)
- 什么情况下,无法使用线段树?
- 如果我们删除或者增加区间中的元素,那么区间的大小将发生变化,此时是无法使用线段树解决这种问题的。(如果只求count可以转化为bucket sort再做rangeSum就没有元素的增加减少了)