简单回溯算法oj
回溯算法先画树形图!!!!!
以下三类题模板基本一样,不同之处只是剪枝的方法。回溯算法剪枝要画出树形图,这样就很清楚哪一步该剪枝,依据其特点剪枝。
模板:
void backtrack(parameters) {
if(满足条件){
//get one answer
record answer;
return;
}
for(解空间树的下一个节点){
backtrack(paramter);
}
}
1、全排列
思路:用一个数组来存储已经搜索过的数据,当遍历的起始位置等于数组长度时,说明找到一个排列。添加到结果集并return
public class permutations {
List<List<Integer>> res = new ArrayList<>();
public List<List<Integer>> permute(int[] nums) {
//因为回溯算法最后会还原上一步操作,用原数组来存储每一次的结果
dfs(nums,0,nums.length - 1);
return res;
}
public void dfs(int[] nums,int start,int end) {
if(start == end) { //说明已经走到nums的末尾,找到一个结果
List<Integer> temp = new ArrayList<>();
for(int i = 0; i < nums.length;i++) {
temp.add(nums[i]);
}
res.add(temp);
return;
}
//回溯+递归
for(int i = start; i <= end;i++) {
swap(nums,start,i); //每次交换元素,可以得到两个元素的全排列
//求剩下元素的全排列,如果可以一直递归下去找到结果,那么就会添加结果到结果集中
dfs(nums,start + 1,end);
//不论有没有找到,上面递归跳出后,都会再次交换,将原本交换的两个元素恢复
//即回溯,恢复现场
swap(nums,start,i);
}
}
public void swap(int[] nums,int i,int j) {
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}
dfs函数的第三个参数没有必要,可以通过原数组计算出来
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> permute(int[] nums) {
if(nums.length == 0) return ans;
dfs(0,nums); //不用传数组长度,因为本身传入的数组就是原数组,直接可以计算长度
return ans;
}
public void dfs(int n,int[] array) {
if(n == array.length) { //找到一个解
List<Integer> tmp = new ArrayList<>();
for(int i : array) {
tmp.add(i);
}
ans.add(tmp);
return;
}
for(int i = n; i < array.length;i++) {
swap(n,i,array);
dfs(n + 1, array);
swap(n,i,array);
}
}
public void swap(int i,int j,int[] array) {
int temp = array[i];
array[i] = array[j];
array[j] = temp;
}
思路:原数组有重复数字,需要考虑剪枝。先对原数组进行排序,这样重复的数字就会相邻
public class redundantPermute {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> permuteUnique(int[] nums) {
if (nums.length == 0) return ans;
Arrays.sort(nums); //先排序
dfs(0,nums);
return ans;
}
private void dfs(int i, int[] nums) {
if (i == nums.length) { //找到一个解
List<Integer> tmp = new ArrayList<>();
for (int n :
nums) {
tmp.add(n);
}
ans.add(tmp);
return;
}
for (int j = i; j < nums.length; j++) {
//剪枝处理
if (check(nums,i,j)){
swap(j,i,nums);
dfs(i + 1,nums);
swap(j,i,nums);
}
}
}
//剪枝函数,判断相邻的两个元素是否相同
private boolean check(int[] nums, int start, int end) {
for (int i = start; i < end; i++) {
if (nums[i] == nums[end]) {
return false;
}
}
return true;
}
private void swap(int j, int i, int[] nums) {
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}
2、组合
思路:回溯时每遍历到一个数就用target减去该数,跳出时加上该数
class Solution {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> combinationSum(int[] candidates, int target) {
Arrays.sort(candidates);
if (candidates[0] > target){
return ans;
}else if (candidates[0] == target) {
List<Integer> tmp = new ArrayList<>();
tmp.add(candidates[0]);
ans.add(tmp);
return ans;
}
Stack<Integer> stack = new Stack<>(); //存储过程中的结果
dfs(target,candidates,stack);
return ans;
}
public void dfs(int target,int[] array,Stack<Integer> stack) {
if (0 == target) {
List<Integer> temp = new ArrayList<>();
temp.addAll(stack);
ans.add(temp);
return;
}
if (0 < target) { //target要大于0
for (int i = 0; i < array.length; i++) { //元素可以重复使用,每次从0下标开始
if (!stack.isEmpty()) { //同样是剪枝,剪掉重复项;如223和322 ,数组排过序因此223在322之前已经添加到答案中了
if (array[i] < stack.peek()) {
continue;
}
}
if (array[i] > target ) { //排过序,当前数大于目标即可剪枝,剪掉不合理项
return;
}
target -= array[i];
stack.push(array[i]);
dfs(target,array,stack);
stack.pop();
target += array[i];
}
}
}
}
思路:和39题一样,不过在剪枝时还要多判断一次
public class combinationSum2 {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> combinationSum2(int[] candidates, int target) {
Arrays.sort(candidates); //排序
if (candidates[0] == target) {
List<Integer> tmp = new ArrayList<>();
tmp.add(candidates[0]);
ans.add(tmp);
return ans;
}else if (candidates[0] > target) {
return ans;
}
Stack<Integer> stack = new Stack<>();
dfs(target,candidates,0,stack); //传入起始位置
return ans;
}
public void dfs(int target,int[] array,int index,Stack<Integer> stack) {
if (target == 0) { //找到一个解
List<Integer> temp = new ArrayList<>();
temp.addAll(stack);
ans.add(temp);
return;
}
for (int i = index; i < array.length; i++) { //元素不可以重复使用,每次从index开始
if (check(index,i,array)) { //剪掉重复项
continue;
}
if (array[i] > target) { //剪掉不合理项
return;
}
target -= array[i];
stack.push(array[i]);
dfs(target,array,i+1,stack);
stack.pop();
target += array[i];
}
}
private boolean check(int index, int i, int[] array) {
for (int j = index; j < i; j++) {
if (array[j] == array[i]) {
return true;
}
}
return false;
}
}
思路:这个剪枝比较难想,举个例子
n=7 k = 4,当遍历到5时就没意义了,因为从5开始不够四个数,因此需要剪枝
例如:
当
stack.size()=1
时,还需要3
个数,最后一次合理的解起始位置为5
,其组合为[5,6,7]当
stack.size()=2
时,还需要2
个数,最后一次合理的解起始位置为6
,其组合为[6,7]当
stack.size()=3
时,还需要1
个数,最后一次合理的解起始位置为7
,其组合为[7]可以归纳出:
搜索的上界(最后一次合理解的起始位置) + 剩余要选的元素个数 - 1 = n
那么
搜索的上界 = n - (k - stack.size()) + 1
class Solution {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> combine(int n, int k) {
if(n == 0 || k == 0) return ans;
Stack<Integer> stack = new Stack<>();
dfs(1,n,k,stack);
return ans;
}
public void dfs(int cur,int n,int k,Stack<Integer> stack) {
if(stack.size() == k) { //找到一个解
List<Integer> temp = new ArrayList<>();
temp.addAll(stack);
ans.add(temp);
return;
}
//确定搜索的上界
for(int i = cur;i <= n - (k - stack.size()) + 1;i++) {
stack.push(i);
dfs(i + 1, n, k ,stack);
stack.pop();
}
}
}
3、子集
思路:回溯即可,注意空集也算子集,因此不需要判断stack的大小
class Solution {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> subsets(int[] nums) {
if (nums.length == 0) return ans;
Stack<Integer> stack = new Stack<>();
dfs(0,nums,stack);
return ans;
}
private void dfs(int cur, int[] nums, Stack<Integer> stack) {
List<Integer> temp = new ArrayList<>();
temp.addAll(stack);
ans.add(temp); //空集也算子集,不用判断stack大小
for (int i = cur; i <= nums.length - 1; i++) {
stack.push(nums[i]);
dfs(i + 1,nums,stack);
stack.pop();
}
}
}
思路:剪枝方法与之前相同,对元素排序使相同元素相邻
class Solution {
List<List<Integer>> ans = new ArrayList<>();
public List<List<Integer>> subsetsWithDup(int[] nums) {
if(nums.length == 0) return ans;
Stack<Integer> stack = new Stack<>();
Arrays.sort(nums);
dfs(0,nums,stack);
return ans;
}
public void dfs(int cur, int[] nums, Stack<Integer> stack) {
List<Integer> temp = new ArrayList<>();
temp.addAll(stack);
ans.add(temp);
for(int i = cur; i < nums.length;i++) {
if(check(cur,i,nums)) {
continue;
}
stack.push(nums[i]);
dfs(i + 1,nums,stack);
stack.pop();
}
}
public boolean check(int cur,int end,int[] nums) {
for(int i = cur;i < end;i++) {
if(nums[i] == nums[end]) {
return true;
}
}
return false;
}
}