线段树之单点更新

328 阅读8分钟

引入

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。

线段树可以在 $O(\log N)$ 的时间复杂度内实现单点修改、区间修改、区间查询（区间求和、求区间最大值、求区间最小值）等操作。

线段树

线段树的基本结构与建树

过程

image.png

image.png

代码实现

线段树一般开 $4$ 倍空间（数组大小取 $4N$，按堆式下标 `u << 1`、`u << 1 | 1` 存储时可保证所有结点下标都不越界）。

  • 创建线段树

递归构建线段树,直到来到叶子结点,在回溯时更新父节点信息。

// Children of heap-indexed node u are 2u and 2u+1. The macro bodies are
// parenthesized so they stay correct inside larger expressions (e.g. `rc * 2`);
// they expect a variable named `u` to be in scope where they are used.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;

// Input array (1-indexed) whose values build() copies into the leaves.
int w[N];

// Tree storage: 4N nodes, as is customary for this heap layout.
struct node{
	int l, r;   // segment [l, r] covered by this node
	int sum;    // sum of w[l..r]
}tr[N * 4];

// 更新父节点信息
void pushup(int u){
	tr[u].sum = tr[lc].sum + tr[rc].sum;
}

// Recursively build the subtree rooted at u over [l, r]: leaves copy w[l],
// internal nodes are filled from their children on the way back up.
void build(int u, int l, int r){
	tr[u].l = l;
	tr[u].r = r;
	if(l == r){
		tr[u].sum = w[l];
		return;
	}
	const int mid = (l + r) >> 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	pushup(u);
}
  • 维护线段树信息

对于单点修改,递归寻找,找到叶子结点后修改,回溯时维护父节点的信息。

对于区间求和,假设来到线段 u:

image.png

如果待查询区间 [l, r] 完全包含结点 u 所表示的区间 [tr[u].l, tr[u].r]，直接返回该结点维护的区间和；否则，若 l ≤ mid（与左儿子区间有交）就递归查询左儿子，若 r > mid（与右儿子区间有交）就递归查询右儿子，再把两部分结果相加。



// Point update: add v at position pos, then refresh every ancestor on the way back.
void modify(int u, int pos, int v){
	if(tr[u].l == tr[u].r){
		tr[u].sum += v;
		return;
	}
	const int mid = (tr[u].l + tr[u].r) >> 1;
	modify(pos <= mid ? lc : rc, pos, v);
	pushup(u);
}

// Range sum over [l, r]: a node fully inside the range contributes its whole
// sum; otherwise descend into each child whose segment overlaps [l, r].
int query(int u, int l, int r){
	if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int total = 0;
	if(l <= mid) total += query(lc, l, r);
	if(mid < r) total += query(rc, l, r);
	return total;
}

敌兵布阵

模板题

#include<bits/stdc++.h>
using namespace std;

// Children of heap-indexed node u are 2u and 2u+1. Parenthesized so the macros
// stay correct inside larger expressions; they expect `u` in scope.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;

int n;     // number of positions in the current test case
int w[N];  // initial values read by main, 1-indexed

// Tree storage: node covers [l, r] and stores the sum of values there.
struct node{
	int l, r;
	int sum;
}tr[N * 4];

void pushup(int u) {
	tr[u].sum = tr[lc].sum + tr[rc].sum;
}

// Build the subtree of u over [l, r]; leaves take w[l], parents are merged afterwards.
void build(int u, int l, int r){
	tr[u].l = l;
	tr[u].r = r;
	if(l != r){
		const int mid = (l + r) >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		pushup(u);
		return;
	}
	tr[u].sum = w[l];
}

// Point update: add v to position pos and recompute the ancestors' sums.
void modify(int u, int pos, int v){
	if(tr[u].l != tr[u].r){
		const int mid = (tr[u].l + tr[u].r) >> 1;
		if(pos > mid) modify(rc, pos, v);
		else modify(lc, pos, v);
		pushup(u);
		return;
	}
	tr[u].sum += v;
}

// Sum of the values in [l, r].
int query(int u, int l, int r){
	const bool covered = l <= tr[u].l && tr[u].r <= r;
	if(covered) return tr[u].sum;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int left = 0, right = 0;
	if(l <= mid) left = query(lc, l, r);
	if(r > mid) right = query(rc, l, r);
	return left + right;
}

// Driver: T test cases. Each reads n values, then processes commands until
// "End": "Query a b" prints the sum on [a, b], "Add a b" adds b at a, and
// any other command (e.g. "Sub a b") subtracts b at a.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	int T = 0, cnt = 0;  // T = 0 if the read fails, so no cases run
	cin >> T;
	while(T --) {
		cout << "Case " << ++ cnt << ":\n";
		cin >> n;
		for(int i = 1; i <= n; i ++){
			cin >> w[i];
		}
		build(1, 1, n);

		string op;
		int a, b;
		// Stop on "End" and also on a failed read. At EOF operator>> leaves
		// `op` unchanged, so testing op[0] alone would loop forever.
		while(cin >> op && op[0] != 'E'){
			cin >> a >> b;
			if(op[0] == 'Q') {
				cout << query(1, a, b) << '\n';
			}
			else if(op[0] == 'A') {
				modify(1, a, b);
			}
			else {
				modify(1, a, -b);
			}
		}
	}

	return 0;
}

I Hate It

模板题

#include<iostream>
#include<algorithm>
#include<cstring>
#include<climits>
using namespace std;

using ll = long long;

int n, m;
constexpr int N = 2e5 + 10;
int w[N];  // input values, 1-indexed

// Tree node: covers [l, r] and stores the maximum of w over that segment in v.
struct node {
    int l, r;
    int v; // maximum over [l, r]
} tr[N * 4];

void pushup(int u) {
    tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}

// Build over [l, r]: leaves copy w[l], internal nodes take the max of their children.
void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) {
        tr[u].v = w[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Point assignment: set position pos to val, then refresh the maxima of all
// ancestors on the way back up.
// NOTE(review): assumes "U A B" means "change A's value to B" (as in the
// HDU 1754 statement). The previous max(old, val) silently ignored any update
// that lowers a value, which is wrong under that reading. Confirm against the problem.
void modify(int u, int pos, int val) {
    if (tr[u].l == tr[u].r) {
        tr[u].v = val;
        return;
    }
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (pos <= mid) modify(u << 1, pos, val);
    else modify(u << 1 | 1, pos, val);
    pushup(u);
}

// Range maximum over [l, r]. INT_MIN is the identity element for max.
int query(int u, int l, int r) {
    if (l <= tr[u].l && tr[u].r <= r) return tr[u].v;
    const int mid = (tr[u].l + tr[u].r) >> 1;
    int best = INT_MIN;
    if (l <= mid) best = max(best, query(u << 1, l, r));
    if (mid < r) best = max(best, query(u << 1 | 1, l, r));
    return best;
}

// Driver: repeat until input ends. Read n values and m commands;
// 'Q a b' prints the maximum on [a, b], any other letter updates position a with b.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    while (cin >> n >> m) {
        for (int i = 1; i <= n; i++) {
            cin >> w[i];
        }
        build(1, 1, n);

        while (m--) {
            char op;
            int a, b;
            cin >> op >> a >> b;
            if (op != 'Q') {
                modify(1, a, b);
                continue;
            }
            cout << query(1, a, b) << '\n';
        }
    }

    return 0;
}

Minimum Inversion Number

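先用线段树求出初始序列的逆序对数：输入的 0..n-1 先统一加一变为 1..n，然后从左到右扫描，对每个数查询之前已出现的、比它大的数的个数，再把它插入线段树。

设当前开头的数为 x，它后面有 x-1 个比它小的数；把它移到末尾之后，它左边有 n-x 个比它大的数。因此逆序数的变化量为 (n-x)-(x-1)。依次移动 n 次恰好枚举了所有循环移位（最后一次回到原序列），取其中的最小值即可。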

#include<bits/stdc++.h>
using namespace std;

// Children of heap-indexed node u are 2u and 2u+1. Parenthesized so the macros
// stay correct inside larger expressions; they expect `u` in scope.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;

int n, a[N];  // sequence length and the sequence (main shifts values by +1)

// Counting tree over values: sum = how many inserted values fall in [l, r].
struct node{
	int l, r;
	int sum;
}tr[N * 4];

// Parent count = left count + right count.
void pushup(int u) {
	int total = tr[lc].sum;
	total += tr[rc].sum;
	tr[u].sum = total;
}

// Build an all-zero tree over [l, r]; nothing has been inserted yet.
void build(int u, int l, int r){
	tr[u] = {l, r, 0};
	if(l < r){
		const int mid = (l + r) >> 1;
		build(lc, l, mid);
		build(rc, mid + 1, r);
		pushup(u);
	}
}

// Point update: add v to the count at position pos, then fix the ancestors.
void modify(int u, int pos, int v){
	const bool leaf = tr[u].l == tr[u].r;
	if(leaf){
		tr[u].sum += v;
		return;
	}
	const int mid = (tr[u].l + tr[u].r) >> 1;
	modify(pos <= mid ? lc : rc, pos, v);
	pushup(u);
}

// Number of inserted values lying in [l, r].
int query(int u, int l, int r){
	if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int total = 0;
	if(mid >= l) total += query(lc, l, r);
	if(mid < r) total += query(rc, l, r);
	return total;
}

// Driver: for each sequence, print the minimum inversion count over all of its
// cyclic shifts (moving the head element to the tail repeatedly).
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	while(cin >> n) {
		// Expects a permutation of 0..n-1; shift it to 1..n for the tree.
		for(int i = 1; i <= n; i ++){
			cin >> a[i];
			a[i] += 1;
		}
		build(1, 1, n);

		// Initial inversions: for each element, count earlier elements larger than it.
		int inv = 0;
		for(int i = 1; i <= n; i ++){
			if(a[i] != n) inv += query(1, a[i] + 1, n);
			modify(1, a[i], 1);
		}

		// Moving head x to the tail removes x-1 inversions and adds n-x.
		// The n-th move restores the original order, so every shift is considered.
		int best = 1E9;
		for(int i = 1; i <= n; i ++){
			inv += (n - a[i]) - (a[i] - 1);
			best = min(best, inv);
		}
		cout << best << '\n';
	}

	return 0;
}

Tunnel Warfare

这题题意表述不清，需要注意：

  • 这题是多组数据

  • 注意同一个村庄可能被摧毁多次

#include<bits/stdc++.h>
using namespace std;

// Children of heap-indexed node u are 2u and 2u+1. Parenthesized so the macros
// stay correct inside larger expressions; they expect `u` in scope.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;

int n, m, a[N];  // village count, operation count; a is currently unused

// sum = number of intact positions (leaf value 1) in [l, r].
struct node{
	int l, r;
	int sum;
}tr[N * 4];

void pushup(int u) {
	tr[u].sum = tr[lc].sum + tr[rc].sum;
}

// Build over [l, r] with every position set to 1 (intact). After pushup, each
// internal node holds the length of its segment.
void build(int u, int l, int r){
	tr[u].l = l;
	tr[u].r = r;
	tr[u].sum = 1;
	if(l == r) return;
	const int mid = (l + r) >> 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	pushup(u);
}

// Point assignment: set position pos to v (main uses 0 = destroyed, 1 = rebuilt).
void modify(int u, int pos, int v){
	if(tr[u].l != tr[u].r){
		const int mid = (tr[u].l + tr[u].r) >> 1;
		if(pos > mid) modify(rc, pos, v);
		else modify(lc, pos, v);
		pushup(u);
		return;
	}
	tr[u].sum = v;
}

// Number of intact positions in [l, r].
int query(int u, int l, int r){
	if(l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int left = 0, right = 0;
	if(l <= mid) left = query(lc, l, r);
	if(mid < r) right = query(rc, l, r);
	return left + right;
}

// Driver: multiple test cases of n villages and m operations.
//   D x : destroy village x
//   R   : rebuild the most recently destroyed village
//   Q x : print the length of the maximal intact run containing x (0 if x is destroyed)
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	while (cin >> n >> m) {
		// Destroyed villages, most recent on top. The same village may appear
		// more than once, since it can be destroyed repeatedly.
		stack<int> des;
		build(1, 1, n);
		while(m --) {
			char op;
			int x;
			cin >> op;
			if(op == 'D') {
				cin >> x;
				modify(1, x, 0);
				des.push(x);
			}
			else if(op == 'R') {
				// Ignore "R" when nothing has been destroyed: top() on an empty stack is UB.
				if(!des.empty()) {
					modify(1, des.top(), 1);
					des.pop();
				}
			}
			else {
				cin >> x;
				if(query(1, x, x) == 0) {
					cout << "0\n";
					continue ;
				}
				int res = 1; // x itself is intact
				// Extend right: largest r such that [x, r] is entirely intact.
				if(x != n) {
					int l = x, r = n;
					while(l < r) {
						int mid = (l + r + 1) >> 1;
						if(query(1, x, mid) == mid - x + 1) l = mid;
						else r = mid - 1;
					}
					res += l - x;
				}
				// Extend left: smallest l such that [l, x] is entirely intact.
				if(x != 1) {
					int l = 1, r = x;
					while(l < r) {
						int mid = (l + r) >> 1;
						if(query(1, mid, x) == x - mid + 1) r = mid;
						else l = mid + 1;
					}
					res += x - l;
				}
				cout << res << '\n';
			}
		}
	}

	return 0;
}

Billboard

  • 维护区间最大值（每一行剩余宽度的最大值），每次在线段树上找到第一个剩余宽度不小于 x 的行
  • 注意 h 非常大：最多只有 n 张海报，每张只占一行，因此只会用到前 n 行，可令 h = min(h, n)
#include<bits/stdc++.h>
using namespace std;

// Children of heap-indexed node u are 2u and 2u+1. Parenthesized so the macros
// stay correct inside larger expressions; they expect `u` in scope.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 2e5 + 10;

int n, h, w, a[N];

// Each position (presumably a billboard row) starts with remaining width w.
// sum is 64-bit: an internal node holds w times its segment length, which
// overflows int (signed overflow is UB) once w and the segment are large.
struct node{
	int l, r;
	long long sum;  // total remaining width over [l, r]
	int mx;         // largest remaining width of a single position in [l, r]
}tr[N * 4];

void pushup(int u) {
	tr[u].sum = tr[lc].sum + tr[rc].sum;
	tr[u].mx = max(tr[lc].mx, tr[rc].mx);
}

// Build over [l, r]: every leaf starts with width w; parents are merged afterwards.
void build(int u, int l, int r){
	tr[u].l = l;
	tr[u].r = r;
	if(l == r) {
		tr[u].sum = w;
		tr[u].mx = w;
		return;
	}
	const int mid = (l + r) >> 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	pushup(u);
}

// Point update: add v (main passes -x to consume width) to position pos,
// adjusting both its sum and its max, then refresh the ancestors.
void modify(int u, int pos, int v){
	if(tr[u].l != tr[u].r) {
		const int mid = (tr[u].l + tr[u].r) >> 1;
		modify(pos <= mid ? lc : rc, pos, v);
		pushup(u);
		return;
	}
	tr[u].sum += v;
	tr[u].mx += v;
}

// Range sum of remaining width over [l, r] (currently unused by main).
// Returns long long: a range of k positions can hold up to k * w in total,
// which overflows int when w is large.
long long query(int u, int l, int r){
	if(tr[u].l >= l && tr[u].r <= r) {
		return tr[u].sum;
	}
	int mid = (tr[u].l + tr[u].r) >> 1;
	long long sum = 0;
	if(l <= mid) sum += query(lc, l, r);
	if(r > mid) sum += query(rc, l, r);
	return sum;
}

// Largest remaining width among positions in [l, r]; -2E9 acts as -infinity.
int askMax(int u, int l, int r) {
	if(l <= tr[u].l && tr[u].r <= r) return tr[u].mx;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int best = -2E9;
	if(l <= mid) best = max(best, askMax(lc, l, r));
	if(mid < r) best = max(best, askMax(rc, l, r));
	return best;
}

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0); 

	while (cin >> h >> w >> n) {
		if(h > n) h = n;
		build(1, 1, h);
		for(int i = 1; i <= n; i ++){
			int x;
			cin >> x;
			if(askMax(1, 1, h) < x) {
				cout << "-1\n";
				continue ;
			}
			int l = 1, r = h;
			while(l < r) {
				int mid = l + r >> 1;
				if(askMax(1, 1, mid) >= x) r = mid;
				else l = mid + 1;
			}
			modify(1, l, -x);
			cout << l << '\n';
	    }


	}

	return 0;
} 

Coder

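离线处理：先读入全部操作，把所有 add 过的数离散化，并按大小建一棵有序的线段树（每个叶子对应一个数值）。

每个结点维护两项信息：区间内当前元素的个数 cnt，以及 sum[i]，即区间内按从小到大排名（从 0 开始）模 5 余 i 的元素之和。

合并时，右儿子中每个元素的排名要整体加上左儿子的 cnt，因此右儿子的 sum 下标需要循环平移。答案就是根结点的 sum[2]，对应 1-based 下标 ≡ 3 (mod 5) 的元素之和。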

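// Merge rule used by pushup():
//   tr[u].sum[i] = tr[lc].sum[i] + tr[rc].sum[idx]
// The right child's element with local 0-based rank idx has rank idx + cnt(lc)
// in the parent, so we need (idx + cnt(lc)) % 5 == i,
// i.e. idx = ((i - cnt(lc)) % 5 + 5) % 5.
//
// The answer sums 1-based positions 3, 8, 13, ..., i.e. 0-based residue 2,
// so it is tr[1].sum[2].
//
// Example: 1 2 3 | 1 2 3  with cnt(lc) = 3
//   for i = 1: idx = (1 - 3) mod 5 = 3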
#include<bits/stdc++.h>
using namespace std;

// Children of heap-indexed node u are 2u and 2u+1. Parenthesized so the macros
// stay correct inside larger expressions; they expect `u` in scope.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1E5 + 10;
// n: number of operations; op/x: each operation and its argument;
// tmp[1..s]: added values, sorted and deduplicated in main for discretization.
int n, a[N], tmp[N], x[N], s;
string op[N];

struct node {
	int l, r, cnt;       // segment and number of elements currently in it
	long long sum[5];    // sum[i]: sum of elements whose 0-based rank in the node is ≡ i (mod 5)
}tr[N << 2];

// Merge children. The right child's element with local rank k has rank
// cnt(lc) + k in the parent, so its residue class shifts by cnt(lc) mod 5.
void pushup(int u) {
	const int shift = tr[lc].cnt % 5;
	tr[u].cnt = tr[lc].cnt + tr[rc].cnt;
	for(int i = 0; i < 5; i ++) {
		const int j = ((i - shift) % 5 + 5) % 5;
		tr[u].sum[i] = tr[lc].sum[i] + tr[rc].sum[j];
	}
}

void build(int u, int l, int r) {
	tr[u] = {l, r, 0};
	memset(tr[u].sum, 0, sizeof tr[u].sum);
	if(l == r) return ;
	int mid = l + r >> 1;
	build(lc, l, mid);
	build(rc, mid + 1, r);
	pushup(u);
}

// Insert (v > 0) or remove (v < 0) value |v| at discretized position p.
// A leaf holds at most that single value, which has rank 0 within the leaf.
void update(int u, int p, int v) {
	if(tr[u].l != tr[u].r) {
		const int mid = (tr[u].l + tr[u].r) >> 1;
		update(p <= mid ? lc : rc, p, v);
		pushup(u);
		return;
	}
	tr[u].cnt += (v < 0 ? -1 : 1);
	tr[u].sum[0] += v;
}

// Driver: process "add x", "del x" and "sum" operations offline. "sum" prints
// the sum of elements at 1-based sorted positions 3, 8, 13, ...
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	while(cin >> n) {
		// Read every operation first so all added values can be discretized up front.
		s = 0;
		for(int i = 1; i <= n; i ++) {
			cin >> op[i];
			if(op[i][0] != 's') {
				cin >> x[i];
				if(op[i][0] == 'a') {
					a[++ s] = x[i];
					tmp[s] = x[i];
				}
			}
		}
		sort(tmp + 1, tmp + s + 1);
		int sz = unique(tmp + 1, tmp + s + 1) - tmp - 1;
		// With no "add" at all, sz is 0 and build(1, 1, 0) would recurse forever.
		// A single empty leaf keeps the tree valid, and every "sum" prints 0.
		build(1, 1, max(sz, 1));
		for(int i = 1; i <= n; i ++) {
			if(op[i][0] == 'a') {
				int idx = lower_bound(tmp + 1, tmp + sz + 1, x[i]) - tmp;
				update(1, idx, x[i]);
			}
			else if(op[i][0] == 'd') {
				int idx = lower_bound(tmp + 1, tmp + sz + 1, x[i]) - tmp;
				update(1, idx, -x[i]);
			}
			else {
				// 0-based residue 2 corresponds to 1-based positions 3, 8, 13, ...
				cout << tr[1].sum[2] << '\n';
			}
		}
	}

	return 0;
}