线段树

129 阅读4分钟

介绍

1,一种支持范围整体修改和范围整体查询的数据结构

2,解决的问题范畴: 大范围信息可以只由左、右两侧信息加工出, 而不必遍历左右两个子范围的具体状况

线段树实例一

给定一个数组arr,用户希望你实现如下三个方法

1)void add(int L, int R, int V) : 让数组arr[L…R]上每个数都加上V

2)void update(int L, int R, int V) : 让数组arr[L…R]上每个数都变成V

3)int sum(int L, int R) :让返回arr[L…R]这个范围整体的累加和

怎么让这三个方法,时间复杂度都是O(logN)

public static class SegmentTree {
    // arr[]为原序列的信息从0开始,但在arr里是从1开始的
    // sum[]模拟线段树维护区间和
    // lazy[]为累加和懒惰标记
    // change[]为更新的值
    // update[]为更新慵懒标记
    private int MAXN;
    private int[] arr;
    private int[] sum;
    private int[] lazy;
    private int[] change;
    private boolean[] update;

    public SegmentTree(int[] origin) {
        MAXN = origin.length + 1;
        arr = new int[MAXN]; // arr[0] 不用 从1开始使用
        for (int i = 1; i < MAXN; i++) {
            arr[i] = origin[i - 1];
        }
        sum = new int[MAXN << 2]; // 用来支持脑补概念中,某一个范围的累加和信息
        lazy = new int[MAXN << 2]; // 用来支持脑补概念中,某一个范围沒有往下傳遞的纍加任務
        change = new int[MAXN << 2]; // 用来支持脑补概念中,某一个范围有没有更新操作的任务
        update = new boolean[MAXN << 2]; // 用来支持脑补概念中,某一个范围更新任务,更新成了什么
    }

    private void pushUp(int rt) {
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    }

    // 之前的,所有懒增加,和懒更新,从父范围,发给左右两个子范围
    // 分发策略是什么
    // ln表示左子树元素结点个数,rn表示右子树结点个数
    private void pushDown(int rt, int ln, int rn) {
        if (update[rt]) {
            update[rt << 1] = true;
            update[rt << 1 | 1] = true;
            change[rt << 1] = change[rt];
            change[rt << 1 | 1] = change[rt];
            lazy[rt << 1] = 0;
            lazy[rt << 1 | 1] = 0;
            sum[rt << 1] = change[rt] * ln;
            sum[rt << 1 | 1] = change[rt] * rn;
            update[rt] = false;
        }
        if (lazy[rt] != 0) {
            lazy[rt << 1] += lazy[rt];
            sum[rt << 1] += lazy[rt] * ln;
            lazy[rt << 1 | 1] += lazy[rt];
            sum[rt << 1 | 1] += lazy[rt] * rn;
            lazy[rt] = 0;
        }
    }

    // 在初始化阶段,先把sum数组,填好
    // 在arr[l~r]范围上,去build,1~N,
    // rt : 这个范围在sum中的下标
    public void build(int l, int r, int rt) {
        if (l == r) {
            sum[rt] = arr[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
        pushUp(rt);
    }


    // L~R  所有的值变成C
    // l~r  rt
    public void update(int L, int R, int C, int l, int r, int rt) {
        if (L <= l && r <= R) {
            update[rt] = true;
            change[rt] = C;
            sum[rt] = C * (r - l + 1);
            lazy[rt] = 0;
            return;
        }
        // 当前任务躲不掉,无法懒更新,要往下发
        int mid = (l + r) >> 1;
        pushDown(rt, mid - l + 1, r - mid);
        if (L <= mid) {
            update(L, R, C, l, mid, rt << 1);
        }
        if (R > mid) {
            update(L, R, C, mid + 1, r, rt << 1 | 1);
        }
        pushUp(rt);
    }

    // L~R, C 任务!
    // rt,l~r
    public void add(int L, int R, int C, int l, int r, int rt) {
        // 任务如果把此时的范围全包了!
        if (L <= l && r <= R) {
            sum[rt] += C * (r - l + 1);
            lazy[rt] += C;
            return;
        }
        // 任务没有把你全包!
        // l  r  mid = (l+r)/2
        int mid = (l + r) >> 1;
        pushDown(rt, mid - l + 1, r - mid);
        // L~R
        if (L <= mid) {
            add(L, R, C, l, mid, rt << 1);
        }
        if (R > mid) {
            add(L, R, C, mid + 1, r, rt << 1 | 1);
        }
        pushUp(rt);
    }

    // 1~6 累加和是多少? 1~8 rt
    public long query(int L, int R, int l, int r, int rt) {
        if (L <= l && r <= R) {
            return sum[rt];
        }
        int mid = (l + r) >> 1;
        pushDown(rt, mid - l + 1, r - mid);
        long ans = 0;
        if (L <= mid) {
            ans += query(L, R, l, mid, rt << 1);
        }
        if (R > mid) {
            ans += query(L, R, mid + 1, r, rt << 1 | 1);
        }
        return ans;
    }

}

线段树实例二

leetcode.cn/problems/fa…

public static class SegmentTree {
   private int[] max;
   private int[] change;
   private boolean[] update;

   public SegmentTree(int size) {
      int N = size + 1;
      max = new int[N << 2];

      change = new int[N << 2];
      update = new boolean[N << 2];
   }

   private void pushUp(int rt) {
      max[rt] = Math.max(max[rt << 1], max[rt << 1 | 1]);
   }

   // ln表示左子树元素结点个数,rn表示右子树结点个数
   private void pushDown(int rt, int ln, int rn) {
      if (update[rt]) {
         update[rt << 1] = true;
         update[rt << 1 | 1] = true;
         change[rt << 1] = change[rt];
         change[rt << 1 | 1] = change[rt];
         max[rt << 1] = change[rt];
         max[rt << 1 | 1] = change[rt];
         update[rt] = false;
      }
   }

   public void update(int L, int R, int C, int l, int r, int rt) {
      if (L <= l && r <= R) {
         update[rt] = true;
         change[rt] = C;
         max[rt] = C;
         return;
      }
      int mid = (l + r) >> 1;
      pushDown(rt, mid - l + 1, r - mid);
      if (L <= mid) {
         update(L, R, C, l, mid, rt << 1);
      }
      if (R > mid) {
         update(L, R, C, mid + 1, r, rt << 1 | 1);
      }
      pushUp(rt);
   }

   public int query(int L, int R, int l, int r, int rt) {
      if (L <= l && r <= R) {
         return max[rt];
      }
      int mid = (l + r) >> 1;
      pushDown(rt, mid - l + 1, r - mid);
      int left = 0;
      int right = 0;
      if (L <= mid) {
         left = query(L, R, l, mid, rt << 1);
      }
      if (R > mid) {
         right = query(L, R, mid + 1, r, rt << 1 | 1);
      }
      return Math.max(left, right);
   }

}

public HashMap<Integer, Integer> index(int[][] positions) {
   TreeSet<Integer> pos = new TreeSet<>();
   for (int[] arr : positions) {
      pos.add(arr[0]);
      pos.add(arr[0] + arr[1] - 1);
   }
   HashMap<Integer, Integer> map = new HashMap<>();
   int count = 0;
   for (Integer index : pos) {
      map.put(index, ++count);
   }
   return map;
}

public List<Integer> fallingSquares(int[][] positions) {
   HashMap<Integer, Integer> map = index(positions);
   int N = map.size();
   SegmentTree segmentTree = new SegmentTree(N);
   int max = 0;
   List<Integer> res = new ArrayList<>();
   // 每落一个正方形,收集一下,所有东西组成的图像,最高高度是什么
   for (int[] arr : positions) {
      int L = map.get(arr[0]);
      int R = map.get(arr[0] + arr[1] - 1);
      int height = segmentTree.query(L, R, 1, N, 1) + arr[1];
      max = Math.max(max, height);
      res.add(max);
      segmentTree.update(L, R, height, 1, N, 1);
   }
   return res;
}