JZ51 数组中的逆序对

63 阅读2分钟

leetcode.cn/problems/sh…

在数组中的两个数字,如果前面一个数字大于后面的数字,则这两个数字组成一个逆序对。输入一个数组,求出这个数组中的逆序对的总数。

 

示例 1:

输入: [7,5,6,4] 输出: 5  

限制:

0 <= 数组长度 <= 50000

解题思路:

直观来看,使用暴力统计法即可,即遍历数组的所有数字对并统计逆序对数量。此方法时间复杂度为 O(N2) ,观察题目给定的数组长度范围 0≤N≤50000 ,可知此复杂度是不能接受的。

「归并排序」与「逆序对」是息息相关的。归并排序体现了 “分而治之” 的算法思想,具体为:

  • 分: 不断将数组从中点位置划分开(即二分法),将整个数组的排序问题转化为子数组的排序问题;
  • 治: 划分到子数组长度为 1 时,开始向上合并,不断将 较短排序数组 合并为 较长排序数组,直至合并至原数组时完成排序;

如下图所示,为数组 [7,3,2,6,0,1,5,4] 的归并排序过程。

image.png

合并阶段 本质上是 合并两个排序数组 的过程,而每当遇到 左子数组当前元素 > 右子数组当前元素 时,意味着 「左子数组当前元素 至 末尾元素」 与 「右子数组当前元素」 构成了若干 「逆序对」。

因此,考虑在归并排序的合并阶段统计「逆序对」数量,完成归并排序时,也随之完成所有逆序对的统计。

算法流程:

merge_sort() 归并排序与逆序对统计:

  1. 终止条件: 当 l≥r 时,代表子数组长度为 1 ,此时终止划分;

  2. 递归划分: 计算数组中点 m ,递归划分左子数组 merge_sort(l, m) 和右子数组 merge_sort(m + 1, r) ;

  3. 合并与逆序对统计:

    1. 暂存数组 nums 闭区间 [i,r] 内的元素至辅助数组 tmp ;
    2. 循环合并: 设置双指针 i , j 分别指向左 / 右子数组的首元素;
    • 当 i=m+1 时: 代表左子数组已合并完,因此添加右子数组当前元素 tmp[j] ,并执行 j=j+1 ;
    • 否则,当 j=r+1 时: 代表右子数组已合并完,因此添加左子数组当前元素 tmp[i] ,并执行 i=i+1 ;
    • 否则,当 tmp[i]≤tmp[j] 时: 添加左子数组当前元素 tmp[i] ,并执行 i=i+1;
    • 否则(即 tmp[i]>tmp[j])时: 添加右子数组当前元素 tmp[j] ,并执行 j=j+1 ;此时构成 m−i+1 个「逆序对」,统计添加至 res ;
  4. 返回值: 返回直至目前的逆序对总数 res ;

reversePairs() 主函数:

  1. 初始化: 辅助数组 tmp ,用于合并阶段暂存元素;

  2. 返回值: 执行归并排序 merge_sort() ,并返回逆序对总数即可;

如下图所示,为数组 [7,3,2,6,0,1,5,4] 的归并排序与逆序对统计过程。

image.png

复杂度分析:

时间复杂度 O(NlogN) : 其中 N 为数组长度;归并排序使用 O(NlogN) 时间;

空间复杂度 O(N) : 辅助数组 tmp 占用 O(N) 大小的额外空间;

代码

class Solution {
    int result = 0;

    public int reversePairs(int[] nums) {
        mSort(nums, 0, nums.length - 1);
        return result;
    }

    public void mSort(int[] nums, int low, int high) {
        int mid = (low + high) / 2;
        if (low < high) {
            mSort(nums, low, mid);
            mSort(nums, mid + 1, high);
            merge(nums, low, mid, high);
        }
    }

    public void merge(int[] nums, int low, int mid, int high) {
        int[] temp = new int[high - low + 1];
        int i = low;
        int j = mid + 1;
        int k = 0;

        while (i <= mid && j <= high) {
            if (nums[i] <= nums[j]) {
                temp[k++] = nums[i++];
            } else {
                // 关键
                result += mid + 1 - i;
                temp[k++] = nums[j++];
            }
        }

        while (i <= mid) {
            temp[k++] = nums[i++];
        }
        while (j <= high) {
            temp[k++] = nums[j++];
        }

        for (int n = 0; n < temp.length; n++) {
            nums[n + low] = temp[n];
        }
    }
}