首先线段树是一颗二叉树,是二叉树就可以使用数组存储,节点i维护区间[l, r],节点2*i+1维护区间[l, (l+r)>>1],节点2*i+2维护区间[(l+r)>>1+1, r]。
存储二叉树的数组长度为原数组长度的2倍,考虑叶子节点访问不越界,一般将存储二叉树的数组长度开辟为原数组长度的4倍。这里假设树的节点存储其维护区间的区间元素和。
def __init__(self, nums: List[int]):
self.N = len(nums)
self.T = [0]*4*self.N
def inner(i, l, r):
if l == r:
self.T[i] = nums[l]
return
m = (l + r) >> 1
inner(i*2+1, l, m)
inner(i*2+2, m+1, r)
self.T[i] = self.T[i*2+1] + self.T[i*2+2]
inner(0, 0, self.N-1)
线段树上的叶子节点只维护一个元素,若想更新叶子节点的值,也需要相应的更新其祖先节点。
def update(self, i: int, val: int) -> None:
def inner(j, l, r):
if l == r:
t = self.T[j]
self.T[j] = val
return t
m = (l + r) >> 1
if i <= m:
t = inner(j*2+1, l, m)
else:
t = inner(j*2+2, m+1, r)
self.T[j] += val - t
return t
inner(0, 0, self.N-1)
线段树的优点体现在区间查询,这里查询的是区间元素和。
def sumRange(self, i: int, j: int) -> int:
def inner(k, i, j, l, r):
if i == l and j == r:
return self.T[k]
t, m = 0, (l + r) >> 1
if j <= m:
t += inner(k*2+1, i, j, l, m)
elif i > m:
t += inner(k*2+2, i, j, m+1, r)
else:
t += inner(k*2+1, i, m, l, m) + inner(k*2+2, m+1, j, m+1, r)
return t
return inner(0, i, j, 0, self.N-1)
附上一道线段树题目:leetcode 307.区域和检索 - 数组可修改