改写有序表练习(1)——达标的子数组数量

204 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第18天,点击查看活动详情

题目

在这里插入图片描述

力扣链接

区间和的个数

举例(暴力解)

在这里插入图片描述

枚举所有位置开头的子数组(O(N^2)),再求每一个子数组的累加和(O(N));所以暴力解总的时间复杂度为O(N ^ 3)

但是求累加和的时候可以用前缀和优化一下,这样最后一步判断子数组累加和的范围是否达标的时间复杂度就是O(1);但是仍然要枚举所有的子数组,所以总的时间复杂度还是挺高,为O(N ^ 2)

改写归并的方法

归并排序面试题——区间和的个数

改写有序表的方法(O(N*logN))

小结论

子数组必须以0位置结尾的情况下,落在指定范围上的有几个;子数组必须以1位置结尾的情况下,落在指定范围上的有几个;子数组必须以2位置结尾的情况下,落在指定范围上的有几个;子数组必须以3位置结尾的情况下,落在指定范围上的有几个;每一个结尾位置的情况下,落在指定范围上的有几个,所有情况都累加起来就是总的答案。

按这种情况讨论,所有这些子数组,每一个都是不一样的,因为最后一个数位置不一样;以0位置结尾的子数组,最后一个数就必须是0位置的数,以1位置结尾的情况下,最后一个数就必须是1位置的数...

所有的子数组一定都列全了,因为客观上来讲,任何一个子数组,它都一定有一个结尾的位置。所以我们以每个位置结尾来划分子数组,既没有划多,也没有划少, 那么在求0位置结尾的情况下达标有几个,意味着结尾的情况达标几个;如果把每一个位置结尾的子数组都求出正确的数量,都累加起来,就是整体达标的数量

代入具体的例子

假设已经来到了i位置,我想求子数组以i位置结尾的情况下,子数组的累加和落在 [10,60] 的有多少个?

已知条件:0~i 的累加和为100;

那么原问题可以转换为:从0位置开始,有多少个前缀和范围落在 [40,90] 上

在这里插入图片描述

假设当前来到了17位置,并且0 ~ 17位置整体的累加和是100,0 ~ 13位置整体的累加和是50,那么很容易推出14 ~ 17位置的累加和是50。

根据上面的结论,任何一个前缀和只要落在 [40,90]范围,就与之一定能够找到一个以17位置结尾的子数组落在 [10,60] 范围

在这里插入图片描述

抽象化

0 ~i 的累加和为S。求以i位置结尾的子数组累加和落在 [a,b]范围上有多少个?等同于求0位置开始,有多少个前缀和范围落在 [S-b,S-a]范围上。求出了前面有多少个前缀和落在指定范围上,就等同于转换出了有多少个子数组必须以i结尾的累加和落在指定范围上。

在这里插入图片描述

具体例子

在这里插入图片描述

要求实现一个这样的结构

  • 能够往里面加入一个整数
  • 能够范围查询,查询结构里有多少个数落在 [a,b] 范围上(有序表并没有这个功能,但是可以通过查询<key的数量拼出来
  • 可以接收重复数字(经典的有序表不能接受重复数字)

现在的问题如何解决接收重复数字

在这里插入图片描述

如上图,如何知道小于16的key有多少个?

经典的有序表,如AVL树、SB树,都是二叉搜索树,所以按照二叉搜索树的性质去查询即可,

在这里插入图片描述

往右滑就累加,往左滑就不变。所以上图小于46的key有95个

为什么系统实现的有序表不实现这个功能,因为系统不知道你要用啥功能,它咋知道给你用什么样的辅助数据项呢,他不知道,所以他不支持。如果下回给个别的目标,要别的数据项,就需要设计别的数据项,如果这样下去,这个结构就做得非常的重了, 这不利于这个结构的推广。所以作为系统级的结构,只实现有限的功能,

代码

package com.harrison.class25;

import java.util.HashSet;

/**
 * @author Harrison
 * @create 2022-04-06-17:34
 * @motto 众里寻他千百度,蓦然回首,那人却在灯火阑珊处。
 */
public class Code04_CountOfRangeSum {
    public static int countRangeSum1(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] sums = new long[n + 1];
        for (int i = 0; i < n; ++i)
            sums[i + 1] = sums[i] + nums[i];
        return countWhileMergeSort(sums, 0, n + 1, lower, upper);
    }

    private static int countWhileMergeSort(long[] sums, int start, int end, int lower, int upper) {
        if (end - start <= 1)
            return 0;
        int mid = (start + end) / 2;
        int count = countWhileMergeSort(sums, start, mid, lower, upper)
                + countWhileMergeSort(sums, mid, end, lower, upper);
        int j = mid, k = mid, t = mid;
        long[] cache = new long[end - start];
        for (int i = start, r = 0; i < mid; ++i, ++r) {
            while (k < end && sums[k] - sums[i] < lower)
                k++;
            while (j < end && sums[j] - sums[i] <= upper)
                j++;
            while (t < end && sums[t] < sums[i])
                cache[r++] = sums[t++];
            cache[r] = sums[i];
            count += j - k;
        }
        System.arraycopy(cache, 0, sums, start, t - start);
        return count;
    }

    public static class SBTNode{
        public long key;
        public SBTNode l;
        public SBTNode r;
        public long size;// 不同key的size
        public long all;// 总的size

        public SBTNode(long k){
            key=k;
            size=1;
            all=1;
        }
    }

    public static class SizeBalanceTreeSet{
        private SBTNode root;
        private HashSet<Long> set=new HashSet<>();

        private SBTNode rightRotate(SBTNode cur){
            long same=cur.all-(cur.l!=null?cur.l.all:0)-(cur.r!=null?cur.r.all:0);
            SBTNode leftNode=cur.l;
            cur.l=leftNode.r;
            leftNode.r=cur;
            leftNode.size=cur.size;
            cur.size=(cur.l!=null?cur.l.size:0)+(cur.r!=null?cur.r.size:0)+1;
            // all modify
            leftNode.all=cur.all;
            cur.all=(cur.l!=null?cur.l.all:0)+(cur.r!=null?cur.r.all:0)+same;
            return leftNode;
        }

        private SBTNode leftRotate(SBTNode cur) {
            long same = cur.all - (cur.l != null ? cur.l.all : 0) - (cur.r != null ? cur.r.all : 0);
            SBTNode rightNode = cur.r;
            cur.r = rightNode.l;
            rightNode.l = cur;
            rightNode.size = cur.size;
            cur.size = (cur.l != null ? cur.l.size : 0) + (cur.r != null ? cur.r.size : 0) + 1;
            // all modify
            rightNode.all = cur.all;
            cur.all = (cur.l != null ? cur.l.all : 0) + (cur.r != null ? cur.r.all : 0) + same;
            return rightNode;
        }

        private SBTNode maintain(SBTNode cur) {
            if (cur == null) {
                return null;
            }
            long leftSize = cur.l != null ? cur.l.size : 0;
            long leftLeftSize = cur.l != null && cur.l.l != null ? cur.l.l.size : 0;
            long leftRightSize = cur.l != null && cur.l.r != null ? cur.l.r.size : 0;
            long rightSize = cur.r != null ? cur.r.size : 0;
            long rightLeftSize = cur.r != null && cur.r.l != null ? cur.r.l.size : 0;
            long rightRightSize = cur.r != null && cur.r.r != null ? cur.r.r.size : 0;
            if (leftLeftSize > rightSize) {
                cur = rightRotate(cur);
                cur.r = maintain(cur.r);
                cur = maintain(cur);
            } else if (leftRightSize > rightSize) {
                cur.l = leftRotate(cur.l);
                cur = rightRotate(cur);
                cur.l = maintain(cur.l);
                cur.r = maintain(cur.r);
                cur = maintain(cur);
            } else if (rightRightSize > leftSize) {
                cur = leftRotate(cur);
                cur.l = maintain(cur.l);
                cur = maintain(cur);
            } else if (rightLeftSize > leftSize) {
                cur.r = rightRotate(cur.r);
                cur = leftRotate(cur);
                cur.l = maintain(cur.l);
                cur.r = maintain(cur.r);
                cur = maintain(cur);
            }
            return cur;
        }

        private SBTNode add(SBTNode cur,long key,boolean contains){
            if(cur==null){
                return new SBTNode(key);
            }else{
                cur.all++;
                if(key==cur.key){
                    return cur;
                }else{// 还在左滑或者右滑
                    if(!contains){
                        cur.size++;
                    }
                    if(key<cur.key){
                        cur.l=add(cur.l,key,contains);
                    }else{
                        cur.r=add(cur.r,key,contains);
                    }
                    return maintain(cur);
                }
            }
        }

        public void add(long sum){
            boolean contains=set.contains(sum);
            root=add(root,sum,contains);
            set.add(sum);
        }

        public long lessKeySize(long key){
            SBTNode cur=root;
            long ans=0;
            while(cur!=null){
                if(key==cur.key){
                    // 滑到最后的一个边界
                    return ans+=(cur.l!=null?cur.l.all:0);
                }else if(key<cur.key){
                    cur=cur.l;
                }else{
                    ans+=cur.all-(cur.r!=null?cur.r.all:0);
                    cur=cur.r;
                }
            }
            return ans;
        }

        // > 7 8...
        // <8 ...<=7
        public long moreKeySize(long key) {
            return root != null ? (root.all - lessKeySize(key + 1)) : 0;
        }
    }

    public static int countRangeSum2(int[] nums,int lower,int upper){
        // 黑盒,加入数字(前缀和),不去重,可以接受重复数字
        // < num , 有几个数?
        SizeBalanceTreeSet treeSet=new SizeBalanceTreeSet();
        long sum=0;
        int ans=0;
        treeSet.add(0);// 一个数都没有的时候,就已经有一个前缀和累加为0
        for(int i=0; i<nums.length; i++){
            sum+=nums[i];
            // [sum - upper, sum - lower]
            // [10, 20] ?
            // < 10 ?  < 21 ?
            long a=treeSet.lessKeySize(sum-lower+1);
            long b=treeSet.lessKeySize(sum-upper);
            ans+=a-b;
            treeSet.add(sum);
        }
        return ans;
    }

    public static void printArray(int[] arr) {
        for (int i = 0; i < arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    // for test
    public static int[] generateArray(int len, int varible) {
        int[] arr = new int[len];
        for (int i = 0; i < arr.length; i++) {
            arr[i] = (int) (Math.random() * varible);
        }
        return arr;
    }

    public static void main(String[] args) {
        int len = 200;
        int varible = 50;
        System.out.println("test begin");
        for (int i = 0; i < 10000; i++) {
            int[] test = generateArray(len, varible);
            int lower = (int) (Math.random() * varible) - (int) (Math.random() * varible);
            int upper = lower + (int) (Math.random() * varible);
            int ans1 = countRangeSum1(test, lower, upper);
            int ans2 = countRangeSum2(test, lower, upper);
            if (ans1 != ans2) {
                System.out.println("oops");
                printArray(test);
                System.out.println(lower);
                System.out.println(upper);
                System.out.println(ans1);
                System.out.println(ans2);
            }
        }
        System.out.println("test finish");
    }
}