HJ103 Redraiment的走法

117 阅读1分钟

题目介绍

题目链接:www.nowcoder.com/practice/24…

image.png

分析

相信大家对数学归纳法都不陌生,高中就学过,而且思路很简单。比如我们想证明一个数学结论,那么我们先假设这个结论在k < n时成立,然后根据这个假设,想办法推导证明出k = n的时候此结论也成立。如果能够证明出来,那么就说明这个结论对于k等于任何数都成立。

类似的,我们设计动态规划算法,不是需要一个 dp 数组吗?我们可以假设dp[0...i-1]都已经被算出来了,然后问自己:怎么通过这些结果算出dp[i]

直接拿最长递增子序列这个问题举例你就明白了。不过,首先要定义清楚 dp 数组的含义,即dp[i]的值到底代表着什么?

我们的定义是这样的:dp[i]表示以nums[i]这个数结尾的最长递增子序列的长度。 根据这个定义,我们就可以推出 base case:dp[i]初始值为 1,因为以nums[i]结尾的最长递增子序列起码要包含它自己。

举两个例子:

image.png

根据这个定义,我们的最终结果(子序列的最大长度)应该是 dp 数组中的最大值。

int res = 0;
for (int i = 0; i < dp.length; i++) {
    res = Math.max(res, dp[i]);
}
return res;

读者也许会问,刚才的算法演进过程中每个dp[i]的结果是我们肉眼看出来的,我们应该怎么设计算法逻辑来正确计算每个dp[i]呢?

这就是动态规划的重头戏,如何设计算法逻辑进行状态转移,才能正确运行呢?这里需要使用数学归纳的思想:

假设我们已经知道了dp[0..4]的所有结果,我们如何通过这些已知结果推出dp[5]

image.png

根据刚才我们对dp数组的定义,现在想求dp[5]的值,也就是想求以nums[5]为结尾的最长递增子序列。

nums[5] = 3,既然是递增子序列,我们只要找到前面那些结尾比 3 小的子序列,然后把 3 接到这些子序列末尾,就可以形成一个新的递增子序列,而且这个新的子序列长度加一

nums[5]前面有哪些元素小于nums[5]?这个好算,用 for 循环比较一波就能把这些元素找出来。

以这些元素为结尾的最长递增子序列的长度是多少?回顾一下我们对dp数组的定义,它记录的正是以每个元素为末尾的最长递增子序列的长度。

以我们举的例子来说,nums[0]nums[4]都是小于nums[5]的,然后对比dp[0]dp[4]的值,我们让nums[5]和更长的递增子序列结合,得出dp[5] = 3

image.png

for (int j = 0; j < i; j++) {
    if (nums[i] > nums[j]) {
        dp[i] = Math.max(dp[i], dp[j] + 1);
    }
}

i = 5时,这段代码的逻辑就可以算出dp[5]。其实到这里,这道算法题我们就基本做完了。

读者也许会问,我们刚才只是算了dp[5]呀,dp[4],dp[3]这些怎么算呢?类似数学归纳法,你已经可以算出dp[5]了,其他的就都可以算出来:

for (int i = 0; i < nums.length; i++) {
    for (int j = 0; j < i; j++) {
        // 寻找 nums[0..j-1] 中比 nums[i] 小的元素
        if (nums[i] > nums[j]) {
            // 把 nums[i] 接在后面,即可形成长度为 dp[j] + 1,
            // 且以 nums[i] 为结尾的递增子序列
            dp[i] = Math.max(dp[i], dp[j] + 1);
        }
    }
}

结合我们刚才说的 base case,下面我们看一下完整代码:

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.*;
public class Main {
    public static void main(String[] args) throws IOException{
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] nums = new int[n];
        int res = 0;
        //初始化数组
        for(int i = 0; i < n; i++) {
            nums[i] = sc.nextInt();
        }
        int[] dp = new int[n];
        //初始化dp数组
        Arrays.fill(dp, 1);
        for(int i = 0; i < n; i++) {
            for(int j = 0; j < i; j++) {
                if(nums[i] > nums[j]) {
                    //更新dp[i]的值
                    dp[i] = Math.max(dp[i], dp[j]+1);
                }
            }
            //更新最大递增子序列的个数
            res = Math.max(res, dp[i]);
        }
        System.out.println(res);
    } 
}