线段树入门

98 阅读4分钟

先在 OI Wiki 上过一遍线段树的基本概念。

单点修改、区间求和：先过一遍模板。

#include<bits/stdc++.h>
using namespace std;

// Child indices in the heap-style layout: node u's children are 2u and 2u+1.
// Parenthesized so the macros expand safely inside larger expressions
// (without parentheses, `lc - 1` would expand to `u << 1 - 1` == `u << 0`).
#define lc (u << 1)
#define rc (u << 1 | 1)

constexpr int N = 1e5 + 10;  // capacity of the input array (values live in w[1..n])

int n, m;   // number of elements, number of operations
int w[N];   // input values

// One segment-tree node: covers the index range [l, r] and stores its sum.
// NOTE(review): `sum` is `int`; whether it can overflow depends on the input
// limits — confirm against the problem constraints.
struct node{
	int l, r;
	int sum;
}tr[N * 4];  // 4 * N slots: the usual safe bound for this recursive layout

void pushup(int u){
	tr[u].sum = tr[lc].sum + tr[rc].sum;
}

// Build the subtree rooted at u over w[l..r].
// A leaf copies its element; an internal node splits at the midpoint,
// builds both halves, and then takes their combined sum.
void build(int u, int l, int r){
	if(l != r){
		tr[u] = {l, r};
		const int half = (l + r) >> 1;
		build(lc, l, half);
		build(rc, half + 1, r);
		pushup(u);
		return;
	}
	tr[u] = {l, r, w[l]};
}

// Point update: add v to element pos, then refresh sums on the way back up.
void modify(int u, int pos, int v){
	if(tr[u].l != tr[u].r){
		const int half = (tr[u].l + tr[u].r) >> 1;
		modify(pos <= half ? lc : rc, pos, v);
		pushup(u);
	}
	else tr[u].sum += v;
}

int query(int u, int l, int r){
	if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	int mid = tr[u].l + tr[u].r >> 1;
	int sum = 0;
	if(l <= mid) sum = query(lc, l, r);
	if(r > mid) sum += query(rc, l, r);
	return sum;	
}

// Read n values and m operations "k a b":
//   k == 0  -> print the sum of w[a..b]
//   k != 0  -> add b to element a
int main(){
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; ++i) scanf("%d", w + i);
	build(1, 1, n);

	for(int q = 0; q < m; ++q){
		int k, a, b;
		scanf("%d%d%d", &k, &a, &b);
		if(k) modify(1, a, b);
		else printf("%d\n", query(1, a, b));
	}

	return 0;
}

查询区间最大值：在每个节点上额外维护一个区间最大值信息。

#include<bits/stdc++.h>
using namespace std;
// Child indices in the heap-style layout: node u's children are 2u and 2u+1.
// Parenthesized so the macros expand safely inside larger expressions
// (without parentheses, `rc * 2` would expand to `u << 1 | 1 * 2`).
#define lc (u << 1)
#define rc (u << 1 | 1)

constexpr int N = 1e5 + 10;  // capacity of the input array (values live in w[1..n])

int n, m;   // number of elements, number of queries
int w[N];   // input values

// Segment-tree node over [l, r]: range sum and range maximum.
struct node{
    int l, r;
    int sum, mx;
}tr[N * 4];  // 4 * N slots: the usual safe bound for this recursive layout

void pushup(int u){
    tr[u].sum = tr[lc].sum + tr[rc].sum;
    tr[u].mx = max(tr[lc].mx, tr[rc].mx);
}

// Build the subtree rooted at u over w[l..r]. A leaf's sum and maximum are both
// its own value; internal nodes combine their children via pushup.
void build(int u, int l, int r){
    if(l == r){
        tr[u] = {l, r, w[l], w[l]};
        return;
    }
    tr[u] = {l, r};
    const int half = (l + r) >> 1;
    build(lc, l, half);
    build(rc, half + 1, r);
    pushup(u);
}

void modify(int u, int pos, int v){
    if(tr[u].l == tr[u].r) {
        tr[u].sum += v;
        tr[u].mx += v;
    }
    else{
        int mid = tr[u].l + tr[u].r >> 1;
        if(pos <= mid) modify(lc, pos, v);
        else modify(rc, pos, v);
        pushup(u);
    }
}

int query_sum(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;
    if(l <= mid) sum += query_sum(lc, l, r);
    if(r > mid) sum += query_sum(rc, l, r);
    return sum; 
}

// Maximum of w[l..r]; INT_MIN is the neutral starting value when combining halves.
int query_max(int u, int l, int r){
    if(l <= tr[u].l && tr[u].r <= r) return tr[u].mx;
    const int half = (tr[u].l + tr[u].r) >> 1;
    int best = INT_MIN;
    if(l <= half){
        const int got = query_max(lc, l, r);
        if(got > best) best = got;
    }
    if(r > half){
        const int got = query_max(rc, l, r);
        if(got > best) best = got;
    }
    return best;
}

// Read n values, build the tree, then answer m range-maximum queries "a b".
int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; ++i) scanf("%d", w + i);
    build(1, 1, n);

    for(int q = 0; q < m; ++q){
        int a, b;
        scanf("%d%d", &a, &b);
        const int answer = query_max(1, a, b);
        printf("%d\n", answer);
    }

    return 0;
}

区间加 + 区间求和：引入并讲解懒标记（lazy tag）。

#include<bits/stdc++.h>
using namespace std;
#define int long long

// Child indices in the heap-style layout: node u's children are 2u and 2u+1.
// Parenthesized so the macros expand safely inside larger expressions
// (without parentheses, `lc - 1` would expand to `u << 1 - 1` == `u << 0`).
#define lc (u << 1)
#define rc (u << 1 | 1)

// Note: in this program `int` is `long long` because of `#define int long long` above.
constexpr int N = 1e5 + 10;  // capacity of the input array (values live in w[1..n])
int w[N];

// Segment-tree node over [l, r]: `sum` is the range sum, `add` is the pending
// increment that has been applied to `sum` but not yet pushed to the children.
struct treeNode{
	int l, r, sum, add;
}tr[N * 4];  // 4 * N slots: the usual safe bound for this recursive layout

// 向上更新
void pushup(int u){ 
	tr[u].sum = tr[lc].sum + tr[rc].sum;
}
// Push-down: hand node u's pending increment to both children, then clear it.
// Each child's sum grows by (child length) * tag, and the tag accumulates on the
// child so it can be pushed further down later.
void pushdown(int u){
	const int tag = tr[u].add;
	if(tag == 0) return;
	auto pushTo = [&](int child){
		tr[child].sum += (tr[child].r - tr[child].l + 1) * tag;
		tr[child].add += tag;
	};
	pushTo(lc);
	pushTo(rc);
	tr[u].add = 0;
}
// 建树
void buildTree(int u, int l, int r){
	tr[u] = {l, r, w[l], 0};
	if(l == r) return ;
	int mid = l + r >> 1;
	buildTree(lc, l, mid);
	buildTree(rc, mid + 1, r);
	pushup(u);
}
// 区间加减
void SegmentAdd(int u, int l, int r, int v){ // w[l, r] += v
	if(l <= tr[u].l && r >= tr[u].r){
		tr[u].sum += (tr[u].r - tr[u].l + 1) * v;
		tr[u].add += v;
		return ;
	}
	// 不覆盖, 分裂线段
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if(l <= mid) SegmentAdd(lc, l, r, v);
	if(r > mid) SegmentAdd(rc, l, r, v);
	pushup(u);
}
// Range sum over w[l..r], pushing pending tags down along the visited path.
int AskSegmentSum(int u, int l, int r){
	if(l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
	pushdown(u);
	const int mid = (tr[u].l + tr[u].r) >> 1;
	const int left = l <= mid ? AskSegmentSum(lc, l, r) : 0;
	const int right = r > mid ? AskSegmentSum(rc, l, r) : 0;
	return left + right;
}
// Read n values and m operations:
//   "1 l r v"          -> add v to every element of w[l..r]
//   "<other opt> l r"  -> print the sum of w[l..r]
void solve(){
	int n, m;
	cin >> n >> m;
	for(int i = 1; i <= n; ++i) cin >> w[i];
	buildTree(1, 1, n);
	while(m--){
		int opt, l, r;
		cin >> opt >> l >> r;
		if(opt != 1){
			cout << AskSegmentSum(1, l, r) << '\n';
			continue;
		}
		int v;
		cin >> v;
		SegmentAdd(1, l, r, v);
	}
}

// Entry point. It is declared `signed` because `int` is macro-redefined to `long long`.
// Stream sync and tie are disabled for faster iostream I/O.
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	solve();
	return 0;
}

练习题：I Hate It。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<climits>
using namespace std;

// Inclusive loop helper: i runs from a to b. The bounds are parenthesized so
// expressions like `rep(i, 1, x & 7)` expand correctly (without parentheses the
// condition became `(i <= x) & 7`).
#define rep(i, a, b) for(int i = (a); i <= (b); i ++)
typedef long long ll;

int n, m;               // number of students, number of operations
int const N = 2e5 + 10; // capacity of the input array (values live in w[1..n])
int w[N];

// Segment-tree node over [l, r]; v holds the maximum value in that range.
struct node{
	int l, r;
	int v; // range maximum
}tr[N * 4];  // 4 * N slots: the usual safe bound for this recursive layout

// A node's maximum is the larger of its two children's maxima.
void pushup(int u){
	const int left = tr[2 * u].v;
	const int right = tr[2 * u + 1].v;
	tr[u].v = left < right ? right : left;
}

// Build the subtree rooted at u over w[l..r]; a leaf stores its own value.
void build(int u, int l, int r){
	if(l != r){
		tr[u] = {l, r};
		const int mid = (l + r) >> 1;
		build(2 * u, l, mid);
		build(2 * u + 1, mid + 1, r);
		pushup(u);
		return;
	}
	tr[u] = {l, r, w[l]};
}

// Point update: the leaf at pos keeps the larger of its current value and val
// (so it never decreases); ancestors are refreshed on the way back up.
void modify(int u, int pos, int val){
	if(tr[u].l == tr[u].r){
		if(tr[u].v < val) tr[u].v = val;
		return;
	}
	const int mid = (tr[u].l + tr[u].r) >> 1;
	if(pos > mid) modify(2 * u + 1, pos, val);
	else modify(2 * u, pos, val);
	pushup(u);
}

// Maximum of w[l..r]; INT_MIN is the neutral starting value when combining halves.
int query(int u, int l, int r){
	if(l <= tr[u].l && tr[u].r <= r) return tr[u].v;
	const int mid = (tr[u].l + tr[u].r) >> 1;
	int best = INT_MIN;
	if(l <= mid) best = query(2 * u, l, r);
	if(r > mid){
		const int other = query(2 * u + 1, l, r);
		if(best < other) best = other;
	}
	return best;
}

// Read n scores and m operations "op a b":
//   op == 'Q'  -> print the maximum of w[a..b]
//   otherwise  -> modify(1, a, b)
int main(){
	scanf("%d%d", &n, &m);
	rep(i, 1, n) scanf("%d", &w[i]);
	build(1, 1, n);

	while(m --){
		int a, b;
		char op;
		// " %c" skips leading whitespace and reads exactly one character. The old
		// unbounded "%s" into char[2] could overflow on an unexpected token.
		// Stop on malformed or truncated input instead of using uninitialized a/b.
		if(scanf(" %c%d%d", &op, &a, &b) != 3) break;
		if(op == 'Q') printf("%d\n", query(1, a, b));
		else modify(1, a, b);
	}

	return 0;
}