递归技巧之记忆化搜索

39 阅读3分钟

今天在做算法题目337. 打家劫舍 III - 力扣(LeetCode)的时候学到了一个提高递归效率的方法,很受用。

因为递归写法中包含很多重复计算。

而有些时候,我们需要通过对递归函数返回值做判断或者是累加,这个时候就可以用一个容器来记录递归的返回结果,这样可以省去重复计算的时间,叫记忆化递归。

接下来就通过这道题目来了解记忆化搜索的思路。

题目要求: 小偷又发现了一个新的可行窃的地区。这个地区只有一个入口,我们称之为 root 。

除了 root 之外,每栋房子有且只有一个“父“房子与之相连。一番侦察之后,聪明的小偷意识到“这个地方的所有房屋的排列类似于一棵二叉树”。 如果 两个直接相连的房子在同一天晚上被打劫 ,房屋将自动报警。

给定二叉树的 root 。返回 在不触动警报的情况下 ,小偷能够盗取的最高金额 。

这个题目的暴力解法如下:

class Solution {
public:
    int rob(TreeNode* root) {
        //差不多相当于后序遍历吧,左右中
        if(root == nullptr)//情况一,空节点
        return 0;
        if(root->left == nullptr && root->right == nullptr)//情况二,只有一个节点
        return root->val;

        //第三种情况,root至少有一个子树
        int max_value = 0;
        if(root->left)//如果左子树不为空
        {
            max_value += rob(root->left->left);
            max_value += rob(root->left->right);
        }
        if(root->right)
        {
            max_value += rob(root->right->left);
            max_value += rob(root->right->right);
        }
        max_value += root->val;//max_value存放的是抢劫当前root节点的情况下,能抢劫的最大值。
        return max(rob(root->left) + rob(root->right),max_value);
    }
};

显然会超时,就是因为多了很多重复计算。计算左右孩子节点的时候其实已经把孙子节点计算过了,但是后来还是单独计算了一遍孙子节点。

所以,我们要想办法记录下每个节点对应的返回结果,这样一来当再次遇到该节点的时候,就不用去做重复计算了。

记录所采用的容器首选unordered_map。

class Solution {
public:
    unordered_map<TreeNode*,int>val;//val保存的就是每个节点对应的最大抢劫金额
    int rob(TreeNode* root) {
        //差不多相当于后序遍历吧,左右中
        //这个写法超时了,因为包含很多重复计算,计算左右孩子节点的时候其实已经把孙子节点计算过了,但是后来还是单独计算二零一遍孙子节点
        //所以超时了。对于这种需要用递归函数返回值做判断或者是做累加的情况,就可以用一个容器来记录递归的返回结果,这样可以省去重复计算的时间,叫记忆化递归
        
        if(val.find(root) != val.end())//说明之前已经计算过了,不要再重复计算了
        return val[root];
        //接下来的都是没有计算过的情况
        if(root == nullptr)//情况一,空节点
        {
            val[root] = 0;
            return 0;
        }
        if(root->left == nullptr && root->right == nullptr)//情况二,只有一个节点
        {
            val[root] = root->val;
            return root->val;
        }

        //第三种情况,root至少有一个子树
        int max_value = 0;
        if(root->left)//如果左子树不为空
        {
            max_value += rob(root->left->left);
            max_value += rob(root->left->right);
        }
        if(root->right)
        {
            max_value += rob(root->right->left);
            max_value += rob(root->right->right);
        }
        max_value += root->val;//max_value存放的是抢劫当前root节点的情况下,能抢劫的最大值。
        val[root] = max(rob(root->left) + rob(root->right),max_value);//把最大值保存下来
        return val[root];
    }
};