如何求解区间和在某个范围内的个数?

235 阅读2分钟

题目描述

给你一个整数数组 nums 以及两个整数 lower 和 upper 。求数组中,值位于范围 [lower, upper] (包含 lower 和 upper)之内的 区间和的个数 。区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和,包含 i 和 j (i ≤ j)。

样例

  • 输入:nums = [-2,5,-1], lower = -2, upper = 2
  • 输出:3
  • 解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。

前言

希望看到文章的朋友能够完全看完本篇文章的所有内容,如果你只关注最后的解法没有任何意义,我更希望的是本篇文章的分享能够让你有所收获,所谓授人与鱼,不如授之与渔。

范围(这个很重要,我认为所有不给范围或者给错范围的题都是耍流氓)

  • 1 <= nums.length <= 10^5
  • -2^31 <= nums[i] <= 2^31 - 1
  • -10^5 <= lower <= upper <= 10^5
  • 题目数据保证答案是一个 32 位 的整数

题目分析

这题的意思理解起来比较简单,也就是求任意区间和范围在[lower, upper]内的数量,相信很多同学马上就想到了最简单粗暴的解法,那就是记录前缀和暴力解,pre[j] - pre[i]比较一下是否在范围内,如果在,那么结果数加一。当然这题不会这么简单,这个时候你需要注意一下范围,数组的长度极限可以是10^5,那么暴力O(n^2)还能解吗?我做题流程一般是先理解题意,在看范围,然后在去想解法。

解题思路

既然是求解任意区间和在某个范围内的个数,我们试着写一下比较的表达式,lower <= pre[j] - pre[i] <= upper,我们该如何利用这个表达式呢?这个时候递推思维就起作用了,我们可以把这个问题转化成求以下标j结尾,下标区间[0..j-1, j]的解数,最后把所有的加上不就是答案了么,然后有同学会说,这不就是与上面讲的暴力解法是一样的么,是的,思路确实是一致的,但是我们不能暴力,原因也解释了,那么我们该如何解呢?我们来把表达式转换一下,我们能够得出pre[i]的范围是[pre[j] - upper, pre[j] - lower],讲到这里我们已经把这个问题变成了区间内求数量的问题了,那么我们就可以利用树状数组和线段树来解决!

需要注意的一些细节

  • nums[i]的数据量太大,但nums.length只有10^5,因此我们可以将数据离散化在解
  • 必须要考虑前缀没有的情况,也就是取pre[j] - pre[0]

树状数组解法

public int countRangeSum(int[] a, int mi, int mx) {
    long[] pre = new long[a.length + 1];
    Set<Long> treeSet = new TreeSet<>();
    for (int i = 0; i <= a.length; i++) {
        if (i > 0) {
            pre[i] = pre[i - 1] + a[i - 1];
        }
        treeSet.add(pre[i]);
        treeSet.add(pre[i] - mx);
        treeSet.add(pre[i] - mi);
    }
    // 数据量太大,我们需要将其离散化
    Map<Long, Integer> map = new HashMap<>();
    int count = 0;
    for (Long aLong : treeSet) {
        map.put(aLong, ++count);
    }
    int ans = 0;
    int[] res = new int[count + 1];
    for (int i = 0; i <= a.length; i++) {
        ans += sum(map.get(pre[i] - mi), res) - sum(map.get(pre[i] - mx) - 1, res);
        add(map.get(pre[i]), 1, count, res);
    }
    return ans;
}

public int lowBit(int x) {
    return x & (-x);
}

public void add(int x, int val, int n, int[] res) {
    while (x <= n) {
        res[x] += val;
        x += lowBit(x);
    }
}

public int sum(int x, int[] res) {
    int ans = 0;
    while (x > 0) {
        ans += res[x];
        x -= lowBit(x);
    }
    return ans;
}

线段树解法

static class Tree{
    int l, r, sum;
    public Tree(int l, int r) {
        this.l = l;
        this.r = r;
    }
}

public void pushUp(int k, Tree[] tree) {
    tree[k].sum = tree[k << 1].sum + tree[k << 1 | 1].sum;
}

public int countRangeSum(int[] a, int mi, int mx) {
    long[] pre = new long[a.length + 1];
    Set<Long> treeSet = new TreeSet<>();
    for (int i = 0; i <= a.length; i++) {
        if (i > 0) {
            pre[i] = pre[i - 1] + a[i - 1];
        }
        treeSet.add(pre[i]);
        treeSet.add(pre[i] - mx);
        treeSet.add(pre[i] - mi);
    }
    // 数据量太大,我们需要将其离散化
    Map<Long, Integer> map = new HashMap<>();
    int count = 0;
    for (Long aLong : treeSet) {
        map.put(aLong, ++count);
    }
    Tree[] tree = new Tree[treeSet.size() << 2 | 1];
    build(1, treeSet.size(), 1, tree);
    int ans = 0;
    for (int i = 0; i <= a.length; i++) {
        ans += query(map.get(pre[i] - mx), map.get(pre[i] - mi), 1, tree);
        insert(map.get(pre[i]), 1, 1, tree);
    }
    return ans;
}

private void insert(int pos, int val, int k, Tree[] tree) {
    if (tree[k].l == tree[k].r && tree[k].l == pos) {
        tree[k].sum += val;
        return;
    }
    int mid = (tree[k].l + tree[k].r) >> 1;
    if (pos <= mid) {
        insert(pos, val, k << 1, tree);
    } else {
        insert(pos, val, k << 1 | 1, tree);
    }
    pushUp(k, tree);
}

private int query(int l, int r, int k, Tree[] tree) {
    if (l <= tree[k].l && tree[k].r <= r) {
        return tree[k].sum;
    }
    int mid = (tree[k].l + tree[k].r) >> 1, ans = 0;
    if (l <= mid) {
        ans += query(l, r, k << 1, tree);
    }
    if (r > mid) {
        ans += query(l, r, k << 1 | 1, tree);
    }
    return ans;
}

private void build(int l, int r, int k, Tree[] tree) {
    tree[k] = new Tree(l, r);
    if(l == r) {
        tree[k].sum = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, k << 1, tree);
    build(mid + 1, r, k << 1 | 1, tree);
    pushUp(k, tree);
}

归并排序解法

这个解法我也是后边才会的,确实很有意思,将前缀和数组进行归并排序,在排序的过程中,计算已排好序的左右子数组,算出合并子数组能够得到的答案,这样说起来比较抽象,具体一点就是我们假设lLeft是左数组的节点,要统计右数组中有多个数减去pre[lLeft]满足要求,只要双指针一直往后边移动就可以了

public int countRangeSum(int[] nums, int lower, int upper) {
    long[] pre = new long[nums.length + 1];
    for (int i = 1; i <= nums.length; ++i) {
        pre[i] = pre[i - 1] + nums[i - 1];
    }
    return countRangeSumMergeSort(pre, lower, upper, 0, pre.length - 1);
}

private int countRangeSumMergeSort(long[] pre, int lower, int upper, int l, int r) {
    if (l == r) {
        return 0;
    }
    int mid = (l + r) >> 1, ans = 0;
    ans += countRangeSumMergeSort(pre, lower, upper, l, mid);
    ans += countRangeSumMergeSort(pre, lower, upper, mid + 1, r);
    return ans + doMerge(pre, lower, upper, l, mid, r);
}

private int doMerge(long[] pre, int lower, int upper, int l, int mid, int r) {
    int lLeft = l, rLeft = mid + 1, rRight, ans = 0;
    while (lLeft <= mid) {
        while (rLeft <= r && pre[rLeft] - pre[lLeft] < lower) {
            rLeft++;
        }
        rRight = rLeft;
        while (rRight <= r && pre[rRight] - pre[lLeft] <= upper) {
            rRight++;
        }
        ans += (rRight - rLeft);
        lLeft++;
    }
    long[] sorted = new long[r - l + 1];
    int p = 0, li = l, ri = mid + 1;
    while (li <= mid || ri <= r) {
        if (li > mid) {
            sorted[p++] = pre[ri++];
        } else if (ri > r) {
            sorted[p++] = pre[li++];
        } else {
            if (pre[li] > pre[ri]) {
                sorted[p++] = pre[ri++];
            } else {
                sorted[p++] = pre[li++];
            }
        }
    }
    for (int i = 0; i < sorted.length; i++) {
        pre[i + l] = sorted[i];
    }
    return ans;
}