【算法】相邻两天不能选同一课程 —— DP 优化复盘

6 阅读3分钟

n 天课程安排,每天有 m 节可选课程,第 i 天选择第 j 节课可获得权值 a[i][j]
要求每天都选一节课,且相邻两天不能选同一个课程编号,求最大总权值。


一、初始思路

这题很自然能想到动态规划。

设:

dp[i][j]

表示:

  • 到第 i 天为止
  • i 天选择第 j 节课
  • 所能获得的最大总权值

那么转移就是:

dp[i][j] = a[i][j] + max(dp[i - 1][k])  (k != j)

意思是:

  • 当前天选第 j 节课
  • 前一天只能从所有 k != j 的课程里选一个最优的接上

初始状态:

dp[0][j] = a[0][j]

最终答案:

max(dp[n - 1][j])

二、我一开始的问题

1)复杂度太高

朴素写法中,计算每个 dp[i][j] 时,都要重新扫一遍前一行找最大值:

for (int i = 1; i < n; i++) {
    for (int j = 0; j < m; j++) {
        for (int k = 0; k < m; k++) {
            if (k == j) continue;
            ...
        }
    }
}

时间复杂度是:

O(n * m * m)

也就是 O(nm²)

如果 m 较大,就容易超时。


2)没注意总和可能爆 int

这题每一天都会累加一个权值,如果:

  • n 比较大
  • 单个 a[i][j] 也不小

那么总和可能超出 int 范围,所以更稳妥的写法应该用:

long long

尤其是:

  • dp
  • ans
  • 中间最大值变量

都尽量用 long long


三、优化关键

观察朴素转移:

dp[i][j] = a[i][j] + max(dp[i - 1][k])  (k != j)

对于前一行 dp[i - 1][*],我们真正反复在问的是:

“除去当前列 j 之外,前一行的最大值是多少?”

其实没必要每次都重新扫一遍。

只要预处理出前一行的:

  • 最大值 max1
  • 最大值所在位置 idx1
  • 次大值 max2

那么:

  • 如果 j != idx1,说明当前列不和最大值冲突,直接接 max1
  • 如果 j == idx1,说明最大值这列不能用,只能接 max2

于是每一层只需要两次线性扫描:

  1. 扫一遍前一行,求 max1 / max2 / idx1
  2. 扫一遍当前行,完成转移

这样总复杂度就降成:

O(n * m)

四、优化后的转移

设前一行最优信息为:

max1, max2, idx1

那么当前行:

if (j == idx1) dp[i][j] = a[i][j] + max2;
else           dp[i][j] = a[i][j] + max1;

这就是整道题最核心的优化点。


五、复杂度分析

朴素做法

O(nm²)

优化后

O(nm)

空间

如果保留二维 dp,空间是:

O(nm)

如果再用滚动数组优化,只保留前一行和当前行,可以压成:

O(m)

六、最终代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;

int main() {
    int n, m;
    cin >> n >> m;

    vector<vector<long long>> a(n, vector<long long>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            cin >> a[i][j];
        }
    }

    vector<long long> prev(m), cur(m);

    for (int j = 0; j < m; j++) {
        prev[j] = a[0][j];
    }

    for (int i = 1; i < n; i++) {
        long long max1 = LLONG_MIN, max2 = LLONG_MIN;
        int idx1 = -1;

        for (int j = 0; j < m; j++) {
            if (prev[j] > max1) {
                max2 = max1;
                max1 = prev[j];
                idx1 = j;
            } else if (prev[j] > max2) {
                max2 = prev[j];
            }
        }

        for (int j = 0; j < m; j++) {
            if (j == idx1) cur[j] = a[i][j] + max2;
            else cur[j] = a[i][j] + max1;
        }

        prev = cur;
    }

    long long ans = LLONG_MIN;
    for (int j = 0; j < m; j++) {
        ans = max(ans, prev[j]);
    }

    cout << ans;
    return 0;
}

七、这题复盘时我最该记住的点

这题不难想到 DP,真正要注意的是两个地方:

1. 朴素 DP 会超时

因为每个状态都重新扫一遍前一行,复杂度是 O(nm²)

2. 总和要注意 long long

单个权值不一定大,但累加很多次后可能超 int

所以这题的核心收获不是“会不会写 DP”,而是:

  • 能不能发现朴素转移里的重复计算
  • 能不能用“最大值 + 次大值”把复杂度降下来
  • 能不能注意总和类型要开 long long