算法学习记录--线段树1(单点修改)

119 阅读5分钟

线段树是算法竞赛中常用的的高级数据结构,需要我们熟练掌握并合理使用它。
今天简单记录一下我对单点修改版本的理解并用例题来加以解释我也是刚学的线段树

线段树是一种高级数据结构,也是一种二叉树,时间复杂度为O(logn)。它能够高效的处理区间修改查询等问题。学线段树,就要学习它如何组织数据,然后高效的进行数据查询,修改等操作。

一般情况下会有三个函数:
  • build()用来建树
  • modify()用来修改区间或某一点的值
  • sum()用来求某一区间的和

给定一个长度为 n 的数组,有 q 次操作,操作有两种:

  • 1:给你一个下标i和一个数x,需要你把数组中第i个数加上x;

  • 2:给你一个范围:l 和 r ,让你求出数组中第l个数到第r个总和是多少。

首先,有的朋友一定会疑惑,求区间的和,为什么不用前缀和呢,预处理的时间复杂度为O(n),求区间和的时间复杂度为O(1)啊?可是,如果多次对区间进行修改,数组就变了,需要重新对前缀数组进行处理,那么单次询问的时间复杂度就变为O(n),对比来看,还是线段树功能更强大。
线段树中,每一个叶子节点都是给定的数组的元素,每个节点都有一段自己所管理的区间(范围),下图给出实例。 数组a及其下标给出,然后构造线段树(黑色为值val,红色为管理范围,蓝色是节点编号),然后每次 (l+r)/2 将区间分开,结果得到的二叉树如图,且叶子节点的管理范围:l==rd650851df6feffe72bba410a271292e.png

下面,我将用代码来实现这三个函数。
首先先说明一下结点的数据类型,这里用结构体

struct {
	int val;
	int left, right;//左右孩子的下标
	int l, r;//管理的范围
}node[N];

然后,就是建树的操作

void build(int pos)//pos是当前节点编号
{
	//出口,该节点管理的左右端点相同,叶子节点
	if (node[pos].l == node[pos].r)
	{
		node[pos].val = a[node[pos].l];
		return;
	}
	//不同
	int mid = (node[pos].l + node[pos].r) / 2;
	//左右孩子下标
	int left = ++idx, right = ++idx;
	node[pos].left = left;
	node[pos].right = right;
	//左右孩子的范围[l,r]
	node[left].l = node[pos].l;
	node[left].r = mid;

	node[right].l = mid + 1;
	node[right].r = node[pos].r;
	//递归到下一层
	build(left);
	build(right);

	//val的值:左右孩子的val 的和
	node[pos].val = node[left].val + node[right].val;
}

接着是修改区间某点

void modify(int pos,int l,int x)
{
        //找到要改的节点
	if (node[pos].l == l && node[pos].r == l)
	{
		node[pos].val += x;
		return;
	}
	int mid = (node[pos].l + node[pos].r) / 2;
        //要改的点的范围小于mid,到左侧找,反之去右侧找
	if (l <= mid) modify(node[pos].left, l, x);
	else modify(node[pos].right, l, x);

	//叶子结点的值被修改,它的所有祖先节点都需要修改
	int left = node[pos].left, right = node[pos].right;
	node[pos].val = node[left].val + node[right].val;
}

还有就是求区间和操作

int sum(int pos, int l, int r)
{
	//该节点是我们要的范围
	if (node[pos].l == l && node[pos].r == r)
	{
		return node[pos].val;
	}
	int mid = (node[pos].l + node[pos].r) / 2;
	//都在左边
	if (r <= mid) return sum(node[pos].left, l, r);
	else
	{
		//都在右边
		if (l > mid) return sum(node[pos].right, l, r);
		else
		{
			//既在左又在右
			int x = sum(node[pos].left, l, mid);
			int y = sum(node[pos].right, mid+1, r);
			return x + y;
		}
	}
}

最后,以例题为例,帮助大家理解

P3374 【模板】树状数组 1

AC代码

#include<iostream>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector> 
#include<map>
#include<queue>
#include<cmath>
using namespace std;
#define int long long
#define endl '\n'
typedef pair<int, int> PII;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;

const int N = 1000010;
struct {
	int val;
	int left, right;//左右孩子的下标
	int l, r;//管理的范围
}node[N];

int a[N], idx = 1; // idx是建树过程中,用于给各个树打上序号

//建树
void build(int pos)
{
	//出口,该节点管理的左右端点相同,叶子节点
	if (node[pos].l == node[pos].r)
	{
		node[pos].val = a[node[pos].l];
		return;
	}
	//不同
	int mid = (node[pos].l + node[pos].r) / 2;
	//左右孩子下标
	int left = ++idx, right = ++idx;
	node[pos].left = left;
	node[pos].right = right;
	//左右孩子的范围[l,r]
	node[left].l = node[pos].l;
	node[left].r = mid;

	node[right].l = mid + 1;
	node[right].r = node[pos].r;
	//递归到下一层
	build(left);
	build(right);

	//val的值:左右孩子的val 的和
	node[pos].val = node[left].val + node[right].val;
}

//修改单点
void modify(int pos,int l,int x)
{
	if (node[pos].l == l && node[pos].r == l)
	{
		node[pos].val += x;
		return;
	}
	int mid = (node[pos].l + node[pos].r) / 2;
	if (l <= mid) modify(node[pos].left, l, x);
	else modify(node[pos].right, l, x);

	//叶子结点的值被修改,它的所有祖先节点都需要修改
	int left = node[pos].left, right = node[pos].right;
	node[pos].val = node[left].val + node[right].val;
}

// 求和
int sum(int pos, int l, int r)
{
	//该节点是我们要的范围
	if (node[pos].l == l && node[pos].r == r)
	{
		return node[pos].val;
	}
	int mid = (node[pos].l + node[pos].r) / 2;
	//都在左边
	if (r <= mid) return sum(node[pos].left, l, r);
	else
	{
		//都在右边
		if (l > mid) return sum(node[pos].right, l, r);
		else
		{
			//既在左又在右
			int x = sum(node[pos].left, l, mid);
			int y = sum(node[pos].right, mid+1, r);
			return x + y;
		}
	}
}
signed main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	int n, q, op, x, y;
	cin >> n >> q;
	for (int i = 1; i <= n; i++) cin >> a[i];
	node[1].l = 1, node[1].r = n;
	build(1);
	for (int i = 1; i <= q; i++)
	{
		cin >> op >> x >> y;
		if (op == 1) modify(1, x, y);
		else cout << sum(1, x, y) << endl;
	}
	return 0;
}