引入
线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。
线段树可以在 $O(\log n)$ 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和、求区间最大值、求区间最小值)等操作。
线段树
线段树的基本结构与建树
过程
代码实现
线段树一般开 $4$ 倍空间(即结点数组大小取 $4n$,代码中写作 `tr[N * 4]`)。
- 创建线段树
递归构建线段树,直到来到叶子结点,在回溯时更新父节点信息。
// Child indices of node u in the array-backed tree (root is 1).
// The bodies are parenthesized so the macros expand safely inside any
// expression (unparenthesized, `lc + 1` would expand to `u << 1 + 1`).
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;
// One segment-tree node: it covers positions [l, r] and stores their sum.
// Four times the element count is enough room for every node.
struct node{
    int l, r;
    int sum;
}tr[N * 4];
// Recompute node u's sum from its two children (called on the way back up).
void pushup(int u) {
    const auto leftSum = tr[lc].sum;
    const auto rightSum = tr[rc].sum;
    tr[u].sum = leftSum + rightSum;
}
// Build the subtree rooted at u that covers positions [l, r].
// Assumes a global 1-indexed value array `w` (not declared in this snippet).
void build(int u, int l, int r) {
    // For a leaf (l == r) this stores its value. For an internal node the sum
    // written here is only a placeholder that pushup overwrites below.
    tr[u] = {l, r, w[r]};
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(u);
    }
}
- 维护线段树信息
对于单点修改,递归寻找,找到叶子结点后修改,回溯时维护父节点的信息。
对于区间求和,设待查询区间为 [l, r],当前来到结点 u(它对应线段 [tr[u].l, tr[u].r]):
如果 [l, r] 完全包含结点 u 的线段,直接返回该结点维护的和。
否则,只在与 [l, r] 有交集的左右儿子中递归查询,并把结果相加。
// Point update: add v to the element at position pos, then refresh the sums
// of every ancestor on the way back up.
void modify(int u, int pos, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (pos > mid) modify(rc, pos, v);
        else modify(lc, pos, v);
        pushup(u);
        return;
    }
    tr[u].sum += v;  // reached the leaf for pos
}
// 区间求和
int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid) sum = query(lc, l, r);
if(r > mid) sum += query(rc, l, r);
return sum;
}
敌兵布阵
模板题(单点加减 + 区间求和)
#include<bits/stdc++.h>
using namespace std;
// Child indices of node u (root is 1). The bodies are parenthesized so the
// macros stay correct inside larger expressions such as `lc + 1`.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;
int n;      // number of positions in the current test case
int w[N];   // initial values, 1-indexed
// One segment-tree node: it covers positions [l, r] and stores their sum.
struct node{
    int l, r;
    int sum;
}tr[N * 4];
// Node u's sum is the sum of its two children.
void pushup(int u) {
    const auto fromLeft = tr[lc].sum;
    const auto fromRight = tr[rc].sum;
    tr[u].sum = fromLeft + fromRight;
}
// Build the subtree rooted at u over positions [l, r] from the array w.
void build(int u, int l, int r) {
    // A leaf keeps w[r] (== w[l]). An internal node's sum is recomputed by pushup.
    tr[u] = {l, r, w[r]};
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(u);
    }
}
// Point update: add v (possibly negative) to position pos.
void modify(int u, int pos, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (pos > mid) modify(rc, pos, v);
        else modify(lc, pos, v);
        pushup(u);
        return;
    }
    tr[u].sum += v;
}
// 区间求和
int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid) sum = query(lc, l, r);
if(r > mid) sum += query(rc, l, r);
return sum;
}
// Reads T test cases. Each one has n initial values followed by commands,
// dispatched on the first letter of the command word:
//   'Q' a b -> print the sum of positions [a, b]
//   'A' a b -> add b to position a
//   'E'     -> end of this test case
//   other   -> subtract b from position a
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int T = 0, cnt = 0;  // T stays 0 (no test cases) if reading it fails
    cin >> T;
    while(T --) {
        if(!(cin >> n)) break;  // truncated input
        cout << "Case " << ++ cnt << ":\n";
        for(int i = 1; i <= n; i ++){
            cin >> w[i];
        }
        build(1, 1, n);
        string op;
        int a, b;
        // Stop at the 'E' command, or when input runs out; testing only
        // op[0] would spin forever if the terminator were missing.
        while(cin >> op && op[0] != 'E'){
            cin >> a >> b;
            if(op[0] == 'Q') {
                cout << query(1, a, b) << '\n';
            }
            else if(op[0] == 'A') {
                modify(1, a, b);
            }
            else {
                modify(1, a, -b);
            }
        }
    }
    return 0;
}
I Hate It
模板题(单点修改 + 区间最大值)
#include<iostream>
#include<algorithm>
#include<cstring>
#include<climits>
using namespace std;
using ll = long long;

int n, m;                    // value count and operation count of the current test case
constexpr int N = 2e5 + 10;
int w[N];                    // initial values, 1-indexed

// Segment-tree node covering positions [l, r].
struct node {
    int l, r;
    int v;  // maximum over [l, r]
} tr[N * 4];
// Node u's maximum is the larger of its children's maxima.
void pushup(int u) {
    const int left = u * 2;
    tr[u].v = max(tr[left].v, tr[left + 1].v);
}
// Build the subtree rooted at u over positions [l, r]; leaves take w[l].
void build(int u, int l, int r) {
    if (l == r) {
        tr[u] = {l, r, w[l]};
        return;
    }
    tr[u] = {l, r};  // v is filled in by pushup once both children exist
    const int left = u << 1, right = left | 1;
    const int mid = (l + r) >> 1;
    build(left, l, mid);
    build(right, mid + 1, r);
    pushup(u);
}
// Point assignment: set the value at position pos to val, then refresh the
// maxima of every ancestor on the way back up.
void modify(int u, int pos, int val) {
    if (tr[u].l == tr[u].r) {
        // The update changes the score to val (HDU 1754: "更改为 B"), so
        // overwrite it. Keeping max(old, val) would silently drop any update
        // that lowers a score.
        tr[u].v = val;
        return;
    }
    const int mid = (tr[u].l + tr[u].r) >> 1;
    if (pos <= mid) modify(u << 1, pos, val);
    else modify(u << 1 | 1, pos, val);
    pushup(u);
}
// Maximum over [l, r], restricted to node u's segment.
int query(int u, int l, int r) {
    if (l <= tr[u].l && tr[u].r <= r) return tr[u].v;
    const int left = u << 1, right = left | 1;
    const int mid = (tr[u].l + tr[u].r) >> 1;
    int best = INT_MIN;
    if (l <= mid) best = query(left, l, r);
    if (mid < r) best = max(best, query(right, l, r));
    return best;
}
// Process test cases until EOF: n values, then m operations of the form
// "op a b". 'Q' prints the maximum of [a, b]; any other letter updates position a with b.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    while (cin >> n >> m) {
        for (int i = 1; i <= n; ++i) cin >> w[i];
        build(1, 1, n);
        while (m--) {
            char op;
            int a, b;
            cin >> op >> a >> b;
            if (op != 'Q') {
                modify(1, a, b);
                continue;
            }
            cout << query(1, a, b) << '\n';
        }
    }
    return 0;
}
Minimum Inversion Number
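先用线段树(以数值为下标,维护每个数是否已出现)求出原序列的逆序对数。输入是 0..n-1 的排列,先整体加 1 变成 1..n,方便作为线段树下标。
对于开头的数 x,它后边有 x-1 个比它小的数;把它移到末尾之后,它左边有 n-x 个比它大的数。所以每次移动后逆序数的变化量为 (n-x)-(x-1)。依次移动 n 次,取过程中的最小值即可。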
#include<bits/stdc++.h>
using namespace std;
// Child indices of node u (root is 1). The bodies are parenthesized so the
// macros stay correct inside larger expressions such as `lc + 1`.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;
int n, a[N];  // sequence length and the sequence, 1-indexed
// Node over the value range [l, r]; sum counts how many of those values were inserted.
struct node{
    int l, r;
    int sum;
}tr[N * 4];
// Node u's count is the sum of its children's counts.
void pushup(int u) {
    const auto leftCount = tr[lc].sum;
    const auto rightCount = tr[rc].sum;
    tr[u].sum = leftCount + rightCount;
}
// Build over values [l, r] with every count starting at zero.
void build(int u, int l, int r) {
    tr[u] = {l, r, 0};
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(u);
    }
}
// Point update: add v to the count stored for value pos.
void modify(int u, int pos, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (pos > mid) modify(rc, pos, v);
        else modify(lc, pos, v);
        pushup(u);
        return;
    }
    tr[u].sum += v;
}
// 区间求和
int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid) sum = query(lc, l, r);
if(r > mid) sum += query(rc, l, r);
return sum;
}
// For each test case (until EOF), print the minimum inversion count over all
// cyclic shifts of the sequence. Assumes the input is a permutation of 0..n-1.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    while (cin >> n) {
        int best = 1E9;
        for (int i = 1; i <= n; ++i) {
            cin >> a[i];
            a[i] += 1;  // shift values to 1..n so they can index the tree
        }
        build(1, 1, n);
        // Inversions of the original order: for each a[i], count the
        // previously inserted values that are larger than it.
        int inversions = 0;
        for (int i = 1; i <= n; ++i) {
            if (a[i] != n) inversions += query(1, a[i] + 1, n);
            modify(1, a[i], 1);
        }
        // Move the front element to the back, one step at a time.
        for (int i = 1; i <= n; ++i) {
            inversions += (n - a[i]) - (a[i] - 1);
            best = min(best, inversions);
        }
        cout << best << '\n';
    }
    return 0;
}
Tunnel Warfare
这题题意表述不清,需要注意:
- 这题是多组数据,需要循环读入直到文件结束;
- 同一个村庄可能被摧毁多次,因此用栈记录每一次摧毁操作,恢复(R)时弹出栈顶。
#include<bits/stdc++.h>
using namespace std;
// Child indices of node u (root is 1). The bodies are parenthesized so the
// macros stay correct inside larger expressions such as `lc + 1`.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1e5 + 10;
int n, m, a[N];  // village count and command count of the current test case
// Node over villages [l, r]; sum is how many of them are intact (1 = intact, 0 = destroyed).
struct node{
    int l, r;
    int sum;
}tr[N * 4];
// Node u's intact count is the sum of its children's counts.
void pushup(int u) {
    const auto leftIntact = tr[lc].sum;
    const auto rightIntact = tr[rc].sum;
    tr[u].sum = leftIntact + rightIntact;
}
// Build over villages [l, r]; every village starts intact (leaf value 1).
void build(int u, int l, int r) {
    tr[u] = {l, r, 1};
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(u);
    }
}
// Point assignment: set village pos to v (0 = destroyed, 1 = intact).
void modify(int u, int pos, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (pos > mid) modify(rc, pos, v);
        else modify(lc, pos, v);
        pushup(u);
        return;
    }
    tr[u].sum = v;
}
// 区间求和
int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid) sum = query(lc, l, r);
if(r > mid) sum += query(rc, l, r);
return sum;
}
// Multiple test cases until EOF. Each has n villages (all intact at first)
// and m commands:
//   'D' x : destroy village x (recorded on a stack)
//   'R'   : rebuild the most recently destroyed village
//   other x : print the length of the maximal run of intact villages containing x
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    while (cin >> n >> m) {
        // Every destruction is pushed, so a village destroyed twice appears twice.
        stack<int> des;
        build(1, 1, n);
        while(m --) {
            char op;
            int x;
            cin >> op;
            if(op == 'D') {
                cin >> x;
                modify(1, x, 0);
                des.push(x);
            }
            else if(op == 'R') {
                // Nothing left to rebuild: skip instead of calling top()/pop()
                // on an empty stack, which is undefined behaviour.
                if(!des.empty()) {
                    modify(1, des.top(), 1);
                    des.pop();
                }
            }
            else {
                cin >> x;
                if(query(1, x, x) == 0) {  // x itself is destroyed
                    cout << "0\n";
                    continue ;
                }
                int res = 1;  // x itself
                if(x != n) {
                    // Largest l in [x, n] such that [x, l] is entirely intact.
                    int l = x, r = n;
                    while(l < r) {
                        int mid = (l + r + 1) >> 1;
                        if(query(1, x, mid) == mid - x + 1) l = mid;
                        else r = mid - 1;
                    }
                    res += l - x;
                }
                if(x != 1) {
                    // Smallest l in [1, x] such that [l, x] is entirely intact.
                    int l = 1, r = x;
                    while(l < r) {
                        int mid = (l + r) >> 1;
                        if(query(1, mid, x) == x - mid + 1) r = mid;
                        else l = mid + 1;
                    }
                    res += x - l;
                }
                cout << res << '\n';
            }
        }
    }
    return 0;
}
Billboard
- 维护每一行剩余宽度的区间最大值,二分找到编号最小(最靠上)的、剩余宽度不小于海报宽度的那一行。
- 注意 h 非常大,但最多只会用到 n 行,所以先令 h = min(h, n) 再建树。
#include<bits/stdc++.h>
using namespace std;
// Child indices of node u (root is 1). The bodies are parenthesized so the
// macros stay correct inside larger expressions such as `lc + 1`.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 2e5 + 10;
int n, h, w, a[N];  // poster count, board height (rows), board width
// Node over rows [l, r].
//   sum: total remaining width. Up to N rows each of width w can exceed
//        INT_MAX, and signed overflow is undefined, so this is long long.
//   mx:  largest remaining width of any single row.
struct node{
    int l, r;
    long long sum;
    int mx;
}tr[N * 4];
// Refresh node u from its children: the widest remaining row and the total remaining width.
void pushup(int u) {
    tr[u].mx = max(tr[lc].mx, tr[rc].mx);
    tr[u].sum = tr[lc].sum + tr[rc].sum;
}
// Build over rows [l, r]; every row starts with the full board width w.
void build(int u, int l, int r) {
    tr[u] = {l, r, w, w};
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        pushup(u);
    }
}
// Point update: add v to row pos's remaining width (main passes a negative
// v when a poster is placed).
void modify(int u, int pos, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (pos > mid) modify(rc, pos, v);
        else modify(lc, pos, v);
        pushup(u);
        return;
    }
    tr[u].mx += v;
    tr[u].sum += v;
}
// Total remaining width over rows [l, r], restricted to node u's segment.
// Returned and accumulated as long long: many rows of a wide board can sum
// past INT_MAX, and signed int overflow is undefined behaviour.
long long query(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r) {
        return tr[u].sum;
    }
    int mid = (tr[u].l + tr[u].r) >> 1;
    long long sum = 0;
    if(l <= mid) sum += query(lc, l, r);
    if(r > mid) sum += query(rc, l, r);
    return sum;
}
int askMax(int u, int l, int r) {
if(tr[u].l >= l && tr[u].r <= r) {
return tr[u].mx;
}
int mid = tr[u].l + tr[u].r >> 1;
int mx = -2E9;
if(l <= mid) mx = max(mx, askMax(lc, l, r));
if(r > mid) mx = max(mx, askMax(rc, l, r));
return mx;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
while (cin >> h >> w >> n) {
if(h > n) h = n;
build(1, 1, h);
for(int i = 1; i <= n; i ++){
int x;
cin >> x;
if(askMax(1, 1, h) < x) {
cout << "-1\n";
continue ;
}
int l = 1, r = h;
while(l < r) {
int mid = l + r >> 1;
if(askMax(1, 1, mid) >= x) r = mid;
else l = mid + 1;
}
modify(1, l, -x);
cout << l << '\n';
}
}
return 0;
}
Coder
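线段树以(离散化后的)数值为下标,每个结点维护其值域区间内当前集合的元素个数 cnt,以及 sum[0..4]:sum[i] 表示把该区间内的元素从小到大排序后,下标(从 0 开始)模 5 余 i 的元素之和。
先离线读入所有操作,对所有 add 的数离散化,再建线段树,这样叶子从左到右按数值有序。
合并左右儿子时,右儿子中元素的下标要加上左儿子的 cnt,据此把右儿子的 sum 错位相加。答案为根结点的 sum[2],即(1-indexed)下标模 5 余 3 的元素之和。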
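// Merge rule used in pushup:
//   tr[u].sum[i] = tr[lc].sum[i] + tr[rc].sum[idx]
// An element at local index idx in the right child has index idx + lnum in
// the parent (lnum = tr[lc].cnt), so we need (idx + lnum) % 5 == i,
// i.e. idx = (i - lnum) mod 5, taken non-negative.
// Example: left child {1 2 3}, right child {1 2 3} => lnum = 3, and for
// i = 1 the right child contributes tr[rc].sum[(1 - 3) mod 5] = tr[rc].sum[3].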
#include<bits/stdc++.h>
using namespace std;
// Child indices of node u (root is 1). The bodies are parenthesized so the
// macros stay correct inside larger expressions such as `lc + 1`.
#define lc (u << 1)
#define rc (u << 1 | 1)
constexpr int N = 1E5 + 10;
// n: number of operations; x[i]: operand of operation i;
// tmp[1..s]: values of the add operations, sorted and de-duplicated in main
// for discretization; a[1..s]: the same values in input order; s: add count.
int n, a[N], tmp[N], x[N], s;
// op[i]: the i-th operation word; main dispatches on its first character.
std::string op[N];
// One node of the value-indexed segment tree.
struct node {
    int l, r;          // range of discretized value ranks covered by this node
    int cnt;           // how many of those values are currently in the set
    long long sum[5];  // sum[i]: sum of the present values whose 0-based rank
                       // within this node is congruent to i modulo 5
} tr[N << 2];
// Merge the children of u. Ranks in the right child are offset by the left
// child's count, so its sums are rotated by that offset (mod 5).
void pushup(int u) {
    const int shift = tr[lc].cnt % 5;
    tr[u].cnt = tr[lc].cnt + tr[rc].cnt;
    for (int i = 0; i < 5; ++i) {
        const int j = ((i - shift) % 5 + 5) % 5;
        tr[u].sum[i] = tr[lc].sum[i] + tr[rc].sum[j];
    }
}
// Build an empty tree (all counts and sums zero) over value ranks [l, r].
// Stops on l >= r, not only l == r: when no value was ever added, main calls
// build(1, 1, 0). Recursing on that empty range would never reach a leaf and
// would run past the end of tr. The lone root then reports sum 0.
void build(int u, int l, int r) {
    tr[u] = {l, r, 0};
    memset(tr[u].sum, 0, sizeof tr[u].sum);
    if(l >= r) return ;
    int mid = (l + r) >> 1;
    build(lc, l, mid);
    build(rc, mid + 1, r);
    pushup(u);
}
// Insert (v > 0) or remove (v < 0, value -v) the value whose rank is p.
// The leaf's count changes by one and its rank-0 sum by v.
void update(int u, int p, int v) {
    if (tr[u].l != tr[u].r) {
        const int mid = (tr[u].l + tr[u].r) >> 1;
        if (p > mid) update(rc, p, v);
        else update(lc, p, v);
        pushup(u);
        return;
    }
    tr[u].cnt += (v < 0) ? -1 : 1;
    tr[u].sum[0] += v;
}
// Test cases until EOF: n operations, dispatched on the first letter.
// 'a' x adds x, 'd' x deletes x, and 's' prints the root's sum[2].
// All operations are read first so the added values can be discretized.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    while (cin >> n) {
        s = 0;
        for (int i = 1; i <= n; ++i) {
            cin >> op[i];
            if (op[i][0] == 's') continue;
            cin >> x[i];
            if (op[i][0] == 'a') {
                ++s;
                a[s] = x[i];
                tmp[s] = x[i];
            }
        }
        sort(tmp + 1, tmp + s + 1);
        const int sz = unique(tmp + 1, tmp + s + 1) - tmp - 1;
        build(1, 1, sz);
        // 1-based rank of a value among the distinct added values.
        auto rank_of = [&](int value) {
            return int(lower_bound(tmp + 1, tmp + sz + 1, value) - tmp);
        };
        for (int i = 1; i <= n; ++i) {
            switch (op[i][0]) {
                case 'a':
                    update(1, rank_of(x[i]), x[i]);
                    break;
                case 'd':
                    update(1, rank_of(x[i]), -x[i]);
                    break;
                default:
                    cout << tr[1].sum[2] << '\n';
                    break;
            }
        }
    }
    return 0;
}