背包问题详解

316 阅读8分钟

大家好,最近由于刚刚入职要做的事情很多,疏于更新一段时间,从今天开始,我会慢慢恢复更新,与大家分享一些算法方面的经验。

好久没说动态规划了,经过上次的分析,大家应该已经对动态规划有了个大体的认识,今天我们一起来看一个经典的问题--0/1背包问题。可能有些同学觉得背包问题很简单,无非写个判断条件,递归执行就能解决。但是想要拿到最优解,我们仍然有许多需要细细思量的东西。

我们先来看一下题目的定义:给定N种水果的重量跟收益,我们需要把它们放进一个可容重量为C的背包里,使得包里的水果在总重量不超过C的同时拥有最高的收益,假设水果数量有限,一种只能选一个

题目很短,很容易理解,我们再具体化一点,看一个例子。假设我现在要去卖水果,现在的情况如下: 水果: { 苹果, 橙子, 香蕉, 西瓜 } 重量: { 2, 3, 1, 4 } 收益: { 4, 5, 3, 7 } 背包可容重量: 5

先来试试不同的组合的结果: 苹果 + 橙子 (总重量5) => 9 苹果 + 香蕉 (总重量 3) => 7 橙子 + 香蕉 (总重量 4) => 8 香蕉 + 西瓜 (总重量 5) => 10

我们可以看到西瓜跟香蕉是绝配,在有限的重量限制下给我们最大的收益。我们来尝试用算法把它描述出来。如我前面所说,最简单的就是暴力递归,每次遇到一种水果,我们只有两个选择,要么在背包还放得下它的时候把它放进去,要么就直接不放它,这样就能帮我们列举出所有的情形,然后我们只取收益最大的那种。

private int knapsackRecursive(int[] profits, int[] weights, int capacity, int currentIndex) {
        if (capacity <= 0 || currentIndex >= profits.length)
            return 0;

        // 在当前元素可以被放进背包的情况下递归的处理剩余元素
        int profit1 = 0;
        if( weights[currentIndex] <= capacity )
            profit1 = profits[currentIndex] + knapsackRecursive(profits, weights,
                    capacity - weights[currentIndex], currentIndex + 1);

        // 跳过当前元素处理剩余元素
        int profit2 = knapsackRecursive(profits, weights, capacity, currentIndex + 1);

        return Math.max(profit1, profit2);
    }

这样的解法时间复杂度得在O(2^n),数据量稍微一大就会出现明显的耗时。

递归调用
我们可以画一下递归调用的树,由于重量跟收益数组是一成不变的,对我们的算法设计过程没有影响,每次可能变化的只有剩余可用重量跟代表当前处理到哪个元素的索引,从这张递归调用树更加可以确定暴力递归数据越大时更耗时,同时也揭露了有些场景被多次重复计算。哈,这就轮到我们缓存大法出场了!由于只有重量跟索引在处理过程中变化,那我们可以用一个二维数组来存储已经计算的结果。这个过程不必详述,直接上代码:

private int knapsackRecursive(Integer[][] dp, int[] profits, int[] weights, int capacity,
                                  int currentIndex) {

        if (capacity <= 0 || currentIndex >= profits.length)
            return 0;

        // 如果已经算得结果,直接返回
        if (dp[currentIndex][capacity] != null)
            return dp[currentIndex][capacity];

        // 在当前元素可以被放进背包的情况下递归的处理剩余元素
        int profit1 = 0;
        if (weights[currentIndex] <= capacity)
            profit1 = profits[currentIndex] + knapsackRecursive(dp, profits, weights,
                    capacity - weights[currentIndex], currentIndex + 1);

        // 跳过当前元素处理剩余元素
        int profit2 = knapsackRecursive(dp, profits, weights, capacity, currentIndex + 1);

        dp[currentIndex][capacity] = Math.max(profit1, profit2);
        return dp[currentIndex][capacity];
    }

好啦,最终所有的结果都存储在这个二维数组里面,我们可以确定我们不会有超过NC个子问题,N是元素的数量,C是背包可容重量,也就是说,到这儿我们时间空间复杂度都只有O(NC)了。

事情到这里还没有结束,我们来尝试用自下而上的方法来考虑这道题,来看看能不能获得更优解。本质上,我们想在上面的递归过程中,对于每一个索引,每一个剩余的可容重量,我们都想在这一步获得可以的最大收益。处理第3个元素时,我们想获得能拿到的最大收益。处理第4个元素时,我们还是想获得可以拿到的最大收益。(毕竟获取最大利润是每个人的目标嘛)dp[i][c]就代表从最开始i=0时计算到当前i的最大收益。那每次我们也只有两种选择:

  1. 跳过当前元素,那处理到这儿我们只能拿到前面过程中最大收益dp[i-1][c]
  2. 只要重量能放得下,放入这个元素,那么这时候的最大收益就为profit[i] + dp[i-1][c-weight[i]]

最终我们想要获得的最大收益就是这俩中的最大值。dp[i][c] = max (dp[i-1][c], profit[i] + dp[i-1][c-weight[i]])

 public int solveKnapsack(int[] profits, int[] weights, int capacity) {
        if (capacity <= 0 || profits.length == 0 || weights.length != profits.length)
            return 0;

        int n = profits.length;
        int[][] dp = new int[n][capacity + 1];

        // 0空间就0收益
        for(int i=0; i < n; i++)
            dp[i][0] = 0;

        // 在处理第一个元素时,只要它重量可以被背包容下,那肯定放入比不放入收益高
        for(int c=0; c <= capacity; c++) {
            if(weights[0] <= c)
                dp[0][c] = profits[0];
        }

        // 循环处理所有元素所有重量
        for(int i=1; i < n; i++) {
            for(int c=1; c <= capacity; c++) {
                int profit1= 0, profit2 = 0;
                // 包含当前元素
                if(weights[i] <= c)
                    profit1 = profits[i] + dp[i-1][c-weights[i]];
                // 不包含当前元素
                profit2 = dp[i-1][c];
                // 取最大值
                dp[i][c] = Math.max(profit1, profit2);
            }
        }

        // dp的最后一个元素就是最大值
        return dp[n-1][capacity];
    }

这样时间空间复杂度也都在O(N*C)。

那怎么找到选择的元素呢?其实很简单,我们之前说过,不选中当前元素的话,当前的最大收益就是处理前一个元素时的最大收益,换言之,只要在dp里的上下俩元素相同的,那那个索引所代表的元素肯定没被选中,dp里第一个不同的总收益所在的位置就是选中的元素所在的位置。

private void printSelectedElements(int dp[][], int[] weights, int[] profits, int capacity) {
        System.out.print("Selected weights:");
        int totalProfit = dp[weights.length - 1][capacity];
        for (int i = weights.length - 1; i > 0; i--) {
            if (totalProfit != dp[i - 1][capacity]) {
                System.out.print(" " + weights[i]);
                capacity -= weights[i];
                totalProfit -= profits[i];
            }
        }

        if (totalProfit != 0)
            System.out.print(" " + weights[0]);
        System.out.println("");
    }

这个算法够简单吧?但我觉得还是不能就这么结束了,我们大费周章地换了一种思路来解题,取得同样的复杂度就结束了吗,我们再来观察下我们这个算法。我们发现我们在处理当前元素时,我们需要的仅仅是在前一个元素时各个索引最大的收益,再往前的数据我们根本不关心,那这就是一个优化的点,我们可以把dp的size大幅缩减。

static int solveKnapsack(int[] profits, int[] weights, int capacity) {
        if (capacity <= 0 || profits.length == 0 || weights.length != profits.length)
            return 0;

        int n = profits.length;
        // 我们只需要前面一次的结果来获得最优解,因此我们可以把数组缩减成两行
        // 我们用 `i%2` 代替`i` 跟 `(i-1)%2` 代替`i-1`
        int[][] dp = new int[2][capacity+1];

        // 在处理第一个元素时,只要它重量可以被背包容下,那肯定放入比不放入收益高
        for(int c=0; c <= capacity; c++) {
            if(weights[0] <= c)
                dp[0][c] = dp[1][c] = profits[0];
        }

        // 循环处理所有元素所有重量
        for(int i=1; i < n; i++) {
            for(int c=0; c <= capacity; c++) {
                int profit1= 0, profit2 = 0;
                // 包含当前元素
                if(weights[i] <= c)
                    profit1 = profits[i] + dp[(i-1)%2][c-weights[i]];
                // 不包含当前元素
                profit2 = dp[(i-1)%2][c];
                // 取最大值
                dp[i%2][c] = Math.max(profit1, profit2);
            }
        }

        return dp[(n-1)%2][capacity];
    }

这时候空间复杂度就只剩下O(N)了,嘿嘿,这是比较让人满意的结果了。不过要是同学们再丧心病狂一点,再变态一点,再观察一下我们的算法,可以发现其实我们只需要前面一次结果中的两个值dp[c]dp[c-weight[i]]。那我们可不可以把结果都放在一个一维数组里面,来看看:

  1. 当我们访问dp[i]的时候,它还没被当前迭代的结果覆盖掉,可用!
  2. 当我们访问dp[c-weight[i]]的时候,如果weight[i]>0,那么dp[c-weight[i]]是有可能已经被覆盖掉了。

这并不是什么难题,只要我们改变处理顺序就好了:c:capacity-->0。从后往前处理,就能保证我们在修改dp里面任何值得时候,这个被修改的值都用不到了,大家想想,是不是这么个道理。 思路想明白了,那手写代码就很简单了:

static int solveKnapsack(int[] profits, int[] weights, int capacity) {
        if (capacity == 0 || profits.length == 0 || weights.length != profits.length) {
            return 0;
        }
        int n = profits.length;
        int[] dp = new int[capacity + 1];
        for (int i = 1; i <= capacity; i++) {
            if (weights[0] <= i) {
                dp[i] = profits[0];
            }
        }
        for (int j = 1; j < n; j++) {
            for (int c = capacity; c >= 0; c--) {
                int profit1 = 0;
                if (weights[j] <= c) {
                    profit1 = profits[j] + dp[c - weights[j]];
                }
                int profit2 = dp[c];
                dp[c] = Math.max(profit1, profit2);
            }
        }
        return dp[capacity];
    }

现在我们的算法可以说是最优咯!最后大家再来好好地总结下,其实动态规划就是想办法减少不必要的内存消耗,跟复用之前问题的结果来解决现在的问题以用最少的时间解决问题。思路就是这么简单,但是关于内存优化,这就得靠经验的积累了,大家多加练习做手熟了就好啦。