滚动数组是 DP 中的一种编程思想。简单的理解就是让数组滚动起来,每次都使用固定的几个存储空间,来达到压缩,起到优化空间,节省存储空间的作用。主要应用在递推或动态规划中(如01背包问题)。因为DP题目是一个自底向上的扩展过程,我们常常需要用到的是连续的解,前面的解往往可以舍去。所以用滚动数组优化是很有效的。利用滚动数组的话在 N 很大的情况下可以达到压缩存储的作用。
当然是用时间去换空间的
如在01背包问题中,从理解角度讲我们应开dp[i][j]的二维数组,第一维我们存处理到第几个物品,第二维存储容量,但是我们获得dp[i],只需使用dp[i - 1]的信息,dp[i - k],k>1都成了无用空间,因此我们可以将数组开成一维就行,迭代更新数组中内容。
滚动数组也是这个原理,目的也一样,不过这时候的问题常常是不可能缩成一维的了,比如一个dp[i][j],需要由dp[i - 1 ][k],dp[i - 2][k]决定,i<n,0<k<=10;n <=100000000显然缩不成一维,正常我们应该开一个dp[100000005][11]的数组,结果很明显,超内存,其实我们只要开dp[3][11]。dp[i%3][j]由dp[(i - 1)%3][k]和dp[(i - 2)%3][k]决定,空间复杂度差别巨大。
例子:
- 斐波那契数列([剑指 Offer 10- I. 斐波那契数列])(leetcode.cn/problems/fe…)
答案需要取模 1e9+7(1000000007),如计算初始结果为:1000000008,请返回 1。
class Solution {
public int fib(int n) {
int r[] = new int[3];
if(n == 0)
return 0;
r[1] = 1;
r[2] = 1;
for(int i = 3; i <= n; ++i){
// r[i] = r[i - 1] + r[i - 2];
r[0] = r[1];
r[1] = r[2];
r[2] = r[0] + r[1];
r[2] %= 1000000007;
}
return r[2];
}
}
上面这个循环r[i] = r[i - 1] + r[i - 2];只依赖于前两个数据r[i - 1] 和 r[i - 2]; 为了节约空间用滚动数组的做法,可以将整个dp 数组压缩成 dp[3]。
- 01背包问题
public class T2 {
private static int n;
private static int m;
private static int[] v;
private static int[] w;
private static final Scanner sc = new Scanner(System.in);
/**
* 二维状态表示
*/
public static void two_dimension(){
int [][]opt = new int[n + 1][m + 1];
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= m; j++) {
opt[i][j] = opt[i - 1][j]; //不含 i
if(w[i] <= j){ // 包含 i
opt[i][j] = Math.max(opt[i][j], opt[i - 1][j - w[i]] + v[i]);
}
}
}
System.out.println(opt[n][m]);
}
/**
* 由于opt[i][j] = max(opt[i][j], opt[i - 1][j - w[i]] + v[i]),只和第i行和第i - 1行有关,
* 可以利用滚动数组将二维压缩成一维:
* opt[i][j] = opt[i - 1][j]; ——————> opt[j] = opt[j]; (去掉opt[i]、opt[i-1])
* opt[i][j] = Math.max(opt[i][j], opt[i - 1][j - w[i]] + v[i]); ——————> opt[j] = Math.max(opt[j], opt[j - w[i]] + v[i]);
* 这里直接将opt[i - 1][j - w[i]]的[i - 1]去掉是不正确的,因为此时opt[j - w[i]]等价于opt[i][j - w[i]]去掉[i],而且opt[i - 1][j - w[i]]
* 是已经计算过的,这样会导致+v[i]的结果有问题。所以应该逆向遍历,这样opt[i - 1][j - w[i]]就还未被计算过=0。
*/
public static void one_dimension(){
int []opt = new int[m + 1];
for (int i = 1; i <= n; i++) {
for (int j = m; j >= w[i]; j--) {
opt[j] = Math.max(opt[j], opt[j - w[i]] + v[i]);
}
}
System.out.println(opt[m]);
}
public static void main(String[] args) {
n = sc.nextInt();
m = sc.nextInt();
v= new int[n + 1];
w = new int[n + 1];
for (int i = 1; i <= n; i++) {
w[i] = sc.nextInt();
v[i] = sc.nextInt();
}
two_dimension();
one_dimension();
}
}
\