算法:回溯算法秒杀所有递归枚举子序列问题

537 阅读3分钟

递归枚举子序列的通用模板:

vector<vector<int>> ans;//子序列组合
vector<int> temp;//当前子序列
void dfs(int cur,vector<int> &nums){
    if(cur == nums.size()){
//判断是否合法,如果合法判断是否重复,将满足条件的加入答案
        if (isValid()&&notVisited()){
            ans.push_back(temp);
        }
        return;
    }
//选择当前元素
    temp.push_back(nums[cur]);
    dfs(cur+1,nums);
    temp.pop_back();
//不选择当前元素
    dfs(cur+1,nums);
}

这是一个递归枚举子序列的通用模板,即用一个临时数组 temp来保存当前选出的子序列,使用 cur来表示当前位置的下标,在 dfs(cur, nums) 开始之前,[0,cur−1] 这个区间内的所有元素都已经被考虑过,而 [cur,n] 这个区间内的元素还未被考虑。在执行 dfs(cur, nums) 时,我们考虑 cur 这个位置选或者不选,如果选择当前元素,那么把当前元素加入到 temp中,然后递归下一个位置,在递归结束后,应当把 temp 的最后一个元素删除进行回溯;如果不选当前的元素,直接递归下一个位置。

上述通用模板基本可以嵌套所有子序列递归枚举问题。下面举几个实例。

问题1:递增子序列(leetcode 491)

给定一个整型数组, 你的任务是找到所有该数组的递增子序列,递增子序列的长度至少是2。

** 示例: **

** 输入: **

[4, 6, 7, 7] 

**输出: **

[[4, 6], [4, 7], [4, 6, 7], [4, 6, 7, 7], [6, 7], [6, 7, 7], [7,7], [4,7,7]]

说明: 给定数组的长度不会超过15。 数组中的整数范围是 [-100,100]。 给定数组中可能包含重复数字,相等的数字应该被视为递增的一种情况。

解答:  在套用通用模板的基础上,我们需要判断组合的合法性和去重,针对这道题合法性考虑两点,一是temp.size()>=2,二是nums[cur] >= last 一个是子序列长度大于等于2,另一个是当前选择的元素应该大于等于上一个选择的元素。

而判断重复需要考虑nums[cur]==last的情况,也就是当前元素等于上一个元素的情况。这里拿[4,6,7,7]举例,假设现在我们已经选了[4,6],此时对于剩下两个7一共有4种情况考虑:

  1. 前者和后者都选
  2. 前者后者都不选
  3. 选前者不选后者
  4. 选后者不选前者

很容易发现3,4这两种情况会产生重复[4,6,7]所以要避免这种情况,这里我们只取第四种情况,舍弃第三种情况。所以代码实现中我们限制不选择的条件为nums[cur]!=last,迫使只要选择了前者就必须一起选择后者,避免第三种情况。

具体代码如下:

ector<vector<int>> ans;
vector<int> temp;
void dfs(int cur,int last,vector<int> &nums){
    if(cur == nums.size()){
        if (temp.size()>=2){
            ans.push_back(temp);
        }
        return;
    }
    if(nums[cur]>=last){
        temp.push_back(nums[cur]);
        dfs(cur+1,nums[cur],nums);
        temp.pop_back();
    }
    if(nums[cur]!=last)
        dfs(cur+1,last,nums);
}

问题2:组合总和(leetcode216)

找出所有相加之和为 n 的 k 个数的组合。组合中只允许含有 1 - 9 的正整数,并且每种组合中不存在重复的数字。 

 说明: 所有数字都是正整数。 解集不能包含重复的组合。 

 **示例 1: **

输入: k = 3, n = 7 

输出: [[1,2,4]] 

**示例 2: **

输入: k = 3, n = 9 

输出: [[1,2,6], [1,3,5], [2,3,4]]

解答:

先上代码为敬

class Solution {
    vector<vector<int>> ans;
    vector<int> temp;
public:
    vector<vector<int>> combinationSum3(int k, int n) {
        if (n>(1+9)*9/2)
            return ans;
        dfs(1,0,k,n);
        return ans;
    }

    void dfs(int cur,int sum,int k,int n){
        if (temp.size() == k){
            if (sum == n)
                ans.push_back(temp);
            return;
        }
        if(sum>n||cur>9) return;

        temp.push_back(cur);
        dfs(cur+1,sum+cur,k,n);
        temp.pop_back();

        dfs(cur+1,sum,k,n);
    }
};