【数据结构与算法】单调栈详解

3,452 阅读11分钟

单调栈

1. 引入

单调栈解决的是这样一个问题,比如说给一个数组arr = { 5,4,6,7,2,3,0,1 },我想知道每一个元素左边比该元素大且离得最近的元素和右边比该元素大且离得最近的元素都是什么。

如果数组有 N 个元素,经典解法就是来到 i 位置,左边遍历直到比 arr[i] 大的元素为止,右边遍历直到比 arr[i] 大的元素为止。确定一个位置的时间复杂度为O(N),确定 N 个位置的时间复杂度就是O(N^2)。

能不能将确定 N 个位置的时间复杂度降到O(N)?单调栈结构。

同样,如果使用单调栈能够找到每一个元素左边和右边比该元素大且离得最近的元素,同样也能找到每个元素左边和右边比该元素小且离得最近的元素。

2. 流程(无重复)

单调栈本身是支持数组中有重复值的,但是我们为了讲清原理,举得例子中数组是没有重复值的。

首先,准备一个栈。

栈中存储的是数组中元素的下标。为什么不存储元素?是因为下标不仅仅能够表示元素,还能表示元素在数组中的位置,携带的信息更多。

如果要找到数组中每一个元素左右两边比该元素大且离得最近的元素,那么单调栈要保证从栈底到栈顶存储的下标对应的元素是从大到小的。

如果要找到数组中每一个元素左右两边比该元素小且离得最近的元素,那么单调栈要保证从栈底到栈顶存储的下标对应的元素是从小到大的。

本案例只找比该元素大且离得最近的元素。

20211016003349.png

从头开始遍历数组:

  • 如果栈中没有元素,直接将元素的下标压栈。
  • 如果栈中有元素,当前元素和栈顶的下标所指向的元素进行比较:
    • 当前元素比栈顶的下标所指向的元素小,将当前元素的下标压栈。
    • 当前元素比栈顶的下标所指向的元素大,栈顶的下标弹栈,同时记录原栈顶下标对应的元素的信息。原栈顶下标对应的元素左边比该元素大且离得最近的元素就是在栈中原栈顶下标压在下面的相邻下标对应的元素;原栈顶下标对应的元素右边比该元素大且离得最近的元素就是让它的下标弹栈的下标对应的元素。记录完之后,当前元素继续和新栈顶下标对应的元素进行比较。如果栈中只有一个下标,则该下标左边没有比该下标对应的元素大且离得最近的元素,右边正常。

当数组遍历完后,如果栈中还有下标,则进入清算阶段:

  • 如果不是最后一个下标,依次弹出栈顶下标,原栈顶下标对应的元素左边比该元素大的且离得最近的元素就是在栈中原栈顶下标压在下面的相邻下标;原栈顶下标对应的元素右边没有比该元素大的且离得最近的元素。
  • 是最后一个下标,弹出该下标,该下标对应的元素没有左边比该元素大的且离得最近的元素,也没有右边没有比该元素大的且离得最近的元素。

设计这种规则实际上就是在严格维护单调栈的单调性。

3. 流程(有重复)

假设数组中有重复值,那么单调栈中存储的元素就不能只是一个下标了,可能会存储多个下标,这多个下标对应的数组中的值是一样的。

因此在实现上,我们偏向去使用一个链表来作为单调栈的元素类型,同一个链表中所有下标指向的元素值是一样的

这种结构可以处理有重复值的数组,也可以处理无重复值的数组,是万能的。

20211016202434.png

流程上和无重复的大致相同,区别在于:

  • 当前元素比栈顶的下标链表所指向的元素大,栈顶的下标链表弹栈,同时记录原栈顶下标链表中每一个下标对应的元素的信息。原栈顶下标链表中每一个下标对应的元素左边比该元素大且离得最近的元素都是在栈中原栈顶下标链表压在下边的相邻下标链表的最后一个下标对应的元素;原栈顶下标链表中每一个下标对应的元素有右边比该元素大且离得最近的元素就是让它的下标链表弹栈的下标链表中的下标对应的元素(此时下标链表中只会有一个元素)。如果栈中只有一个下标链表,则该链表中所有下标左边没有比该下标对应的元素大且离得最近的元素,右边正常。
  • 当前元素与栈顶的下标链表所指向的元素相等,将该元素对应的下标连接到栈顶的下标链表的末尾。

4. 时间复杂度

为什么说使用单调栈可以将时间复杂度降低至O(N)?

假设有数组中有 N 个元素,在我们计算出了所有元素的左右边比该元素大或者小且离得最近的元素的整个过程中,无论是使用有重复的模型还是无重复的模型,每一个元素都只进栈一次,出栈一次。

5. 原理(无重复)

为什么数组中没有重复值,单调栈可以做到以O(N)的代价找到每个元素左右边比该元素大且离得最近的元素?

假设当前有一个单调栈,栈中有 a 和 b。 现在 c 要压栈,已知 arr[c] > arr[b],因此 b 需要先弹出栈,在 b 弹栈时记录 b 的相关信息。

为什么 arr[b] 左边比 arr[b] 大且离得最近的元素一定是 arr[a] ?

为什么 arr[b] 右边比 arr[b] 大且离得最近的元素一定是 arr[c] ?

20211016094715.png

证明:arr[b] 右边比 arr[b] 大且离得最近的元素是 arr[c]。

因为是从左往右依次遍历数组的,b 比 c 先进了栈,表示 arr[b] 比 arr[c] 先遍历到,因此 arr[c] 一定在 arr[b] 的右边。

那么我们看 b~c 之间,有没有可能有一个 k 使得 arr[k] > arr[b] ?

不可能,如果在 b~c 之间有一个 arr[k] > arr[b],那么肯定先遍历到 k,当遍历到 k 的时候,为了保证栈的单调性,一定会将 b 弹栈,根本轮不到 c 来让 b 弹栈。

证明:arr[b] 左边比 arr[b] 大且离得最近的元素是 arr[a]。

因为是从左往右依次遍历数组的,a 比 b 先进了栈,表示 arr[a] 比 arr[b] 先遍历到,因此 arr[a] 一定在 arr[b] 的左边。

那么我们看 a ~ b 之间,有没有可能有一个 k 使得 arr[k] > arr[b]?

此时我们知道 arr[a] > arr[b],arr[k] > arr[b],因此我们需要讨论一下 arr[a] 和 arr[k] 的关系:

  • 如果 arr[k] > arr[a],因为 k 一定比 a 后遍历到,因此在遍历到 k 时一定会让 a 弹栈,等遍历到 b 让 b 压栈时根本碰不到 a,因此这种情况不可能。
  • 如果 arr[k] < arr[a],因为 k 一定比 a 后遍历到,因此在遍历到 k 时 k 一定会压在 a 的上面。因为 arr[k] > arr[b],且 b 比 k 后遍历到,因此等遍历到 b 时,b 压栈时一定会和 a 之间隔一个 k,a 和 b 根本不会相邻,因此这种情况也不可能。

证明的目的就是要说明:使用单调栈这种结构来解决这种问题,每一步的流程和数据都是正确的。

6. 原理(有重复)

为什么数组中有重复值,单调栈也可以做到以O(N)的代价找到每个元素左右边比该元素大且离得最近的元素?

假设当前有一个单调栈,栈中有 a、b、c、d 和 e。 现在 e 要压栈,已知 arr[e] > arr[d],因此 d 需要先弹出栈,在 d 弹栈时记录 b 的相关信息。

为什么 arr[d] 左边比 arr[d] 大且离得最近的元素一定是 arr[c] ?

为什么 arr[d] 右边比 arr[d] 大且离得最近的元素一定是 arr[e] ?

20211016182642.png

证明:arr[d] 右边比 arr[d] 大且离得最近的元素是 arr[e]。

与无重复同理。

证明:arr[d] 左边比 arr[d] 大且离得最近的元素是 arr[c]。

因为是从左往右依次遍历数组的,c 比 d 先进了栈,表示 arr[c] 比 arr[d] 先遍历到,因此 arr[c] 一定在 arr[d] 的左边。

那么我们看 c ~ d 之间,有没有可能有一个 k 使得 arr[k] > arr[d]?

此时我们知道 arr[c] > arr[d],arr[k] > arr[d],因此我们需要讨论一下 arr[c] 和 arr[k] 的关系:

  • 如果 arr[k] > arr[c],因为 k 一定比 c 后遍历到,因此在遍历到 k 时一定会让 c 弹栈,等遍历到 d 让 d 压栈时根本碰不到 c,因此这种情况不可能。
  • 如果 arr[k] == arr[c],因为 k 一定比 c 后遍历到,因此在遍历到 k 时一定会让 k 连接到 c 后面,等遍历到 e 让 d 弹栈时,d 就会对 k 收集左边的信息,而不是 c,因此这种情况不可能。
  • 如果 arr[k] < arr[c],因为 k 一定比 c 后遍历到,因此在遍历到 k 时 k 一定会压在 c 的上面。因为 arr[k] > arr[d],且 d 比 k 后遍历到,因此等遍历到 d 时,d 压栈时一定会和 c 的下标链表之间隔一个 k,d 和 c 根本不会相邻,因此这种情况也不可能。

7. 实现

下面代码实现的是支持有重复,能找到arr[i]左右两边比arr[i]小且离得最近的元素的单调栈。

public class MonotoneStack {

    private Stack<LinkedList<Integer>> stack = new Stack<>();

    private int[] leftRecord;

    private int[] rightRecord;

    public MonotoneStack(int[] arr) {
        this.leftRecord = new int[arr.length];
        this.rightRecord = new int[arr.length];
        record(arr);
    }

    // 构建记录
    public void record(int[] arr) {
        for (int i = 0; i < arr.length; i ++) {
            // 如果栈空或者arr[i]大于栈顶下标链表对应的元素的值,直接压栈
            if (stack.isEmpty() || arr[i] > arr[stack.peek().getLast()]) {
                LinkedList<Integer> list = new LinkedList<>();
                list.addFirst(i);
                stack.push(list);
            }
            // 如果arr[i]等于栈顶下标链表对应的元素的值,连接到栈顶下标链表的末尾
            else if (arr[i] == arr[stack.peek().getLast()]) {
                stack.peek().addLast(i);
            }
            // 如果arr[i]小于栈顶下标链表对应的元素的值,栈顶下标链表弹栈
            else {
                // 直到arr[i]大于等于栈顶下标链表对应的元素值为止
                while (arr[i] < arr[stack.peek().getLast()]) {
                    LinkedList<Integer> list = stack.pop();

                    // 判断是否是栈底下标链表
                    if (stack.isEmpty()) {
                        for (Integer index : list) {
                            leftRecord[index] = index;
                            rightRecord[index] = i;
                        }
                        break;
                    }

                    // 遍历链表
                    for (Integer index : list) {
                        // 左边比arr[i]小且离得最近的是原栈顶下标链表下面链表的最后一个下标
                        leftRecord[index] = stack.peek().getLast();
                        // 右边比arr[i]小且离得最近的是原栈顶下标链表弹栈的下标
                        rightRecord[index] = i;
                    }
                }

                // 如果栈空或者arr[i]大于栈顶下标链表对应的元素
                if (stack.isEmpty() || arr[i] > arr[stack.peek().getLast()]) {
                    // 压栈
                    LinkedList<Integer> list = new LinkedList<>();
                    list.addFirst(i);
                    stack.push(list);
                }
                // arr[i]等于栈顶下标链表对应的元素
                else {
                    // 追加栈顶下标链表
                    stack.peek().addLast(i);
                }
            }
        }

        // 清算阶段
        while (!stack.isEmpty()) {
            LinkedList<Integer> list = stack.pop();

            // 弹出的是否是栈底下标链表
            if (stack.isEmpty()) {
                for (Integer index : list) {
                    leftRecord[index] = index;
                    rightRecord[index] = index;
                }
                break;
            }

            // 如果不是栈底下标链表
            for (Integer index : list) {
                // 左边比arr[i]小且离得最近的是原栈顶下标链表下面链表的最后一个下标
                leftRecord[index] = stack.peek().getLast();
                // 右边没有比arr[i]小且离得最近的元素
                rightRecord[index] = index;
            }
        }
    }

    // 通过下标获取左边比arr[i]小且离得最近的元素的下标,没有就是自己i
    public int getLeftRecord(int i) {
        return leftRecord[i];
    }

    // 通过下标获取右边比arr[i]小且离得最近的元素的下标,没有就是自己i
    public int getRightRecord(int i) {
        return rightRecord[i];
    }

}

8. 应用

题目:

一个正数数组中所有的数累加起来的和乘以这个数组中的最小值得到的积叫做指标A。

一个数组一定有很多子数组(包括自身),那么每个子数组都会有自己的指标A。

给一个数组,求出在这个数组的子数组中,最大的指标A是多少?

分析:

20211016204555.png

对于数组中每一位元素,我们都需要找出:该元素在子串中是最小的元素且长度最长的子串。

对找出的所有子串计算指标A,得到最大的就是最后的答案。

对于第 i 位元素,如何找出最小元素是 arr[i] 且长度最长的子串?单调栈。

单调栈的具体作用就是从 arr[i] 开始向左右两边找比 arr[i] 小且离得最近的元素,从而确定子串的左右边界。

代码:

public static int getSum(int[] arr, int i, int j) {
    int sum = 0;
    for (int k = i; k <= j; k ++) {
        sum += arr[k];
    }
    return sum;
}

/**
 * 求第cur位的最大指标A
 * @param arr 数组
 * @param i 子串的左边界
 * @param j 子串的右边界
 * @param cur 当前子串中最小值的下标
 * @return
 */
public static int getA(int[] arr, int i, int j, int cur) {
    int sum = 0;

    // arr[cur]不是子串的左边界,也不是右边界
    if (i != cur && j != cur) {
        sum = getSum(arr, i + 1, j - 1);
    }

    // arr[cur]是子串的左边界
    if (i == cur && j != cur) {
        sum = getSum(arr, i, j - 1);
    }

    // arr[cur]是子串的右边界
    if (i != cur && j == cur) {
        sum = getSum(arr, i + 1, j);
    }

    return sum * arr[cur];
}

public static int getMaxA(int[] arr) {
    if (arr == null || arr.length == 0) {
        return -1;
    }

    // 最大指标A
    int maxA = 0;

    // 构建单调队列
    MonotoneStack stack = new MonotoneStack(arr);

    // 求每一位的最大指标A,比出数组中的最大指标A
    for (int i = 0; i < arr.length; i ++) {
        // 获取构建成最大指标A的子串的左右边界
        int left = stack.getLeftRecord(i);
        int right = stack.getRightRecord(i);

        // 获取第i位的最大指标
        int a = getA(arr, left, right, i);

        // 选出arr中的最大指标A
        if (maxA < a) {
            maxA = a;
        }
    }

    return maxA;
}