之前一直对回溯理解不深,今天刷一下leetcode的回溯专题。不得不说,大佬们写的递归回溯代码真的优美,我i了。
17题
思路树状图:
public class code17 {
//一个映射表,第二个位置是"abc“,第三个位置是"def"。。。
//这里也可以用map,用数组可以更节省点内存
String[] letter_map = {" ","*","abc","def","ghi","jkl","mno","pqrs","tuv","wxyz"};
List<String> res = new ArrayList<>();
public List<String> letterCombinations(String digits){
if(digits==null||"".equals(digits)) return res;
dfs(0,digits,"");
return res;
}
//tmp:当前生成的字符串组合
private void dfs(int pos, String digits,String tmp) {
if(pos==digits.length()){
//说明遍历到底了,准备回溯
this.res.add(tmp);
return;
}
char c = digits.charAt(pos);
int curr = c - '0';
String map_str = letter_map[curr];
//遍历字符串,比如第一次得到的是2,也就是遍历"abc"
for(int i=0;i<map_str.length();i++)
dfs(pos+1,digits,tmp+map_str.charAt(i));
}
}
93题(和91题很像)
我们要知道IP的格式,每位是在0~255之间,
注意: 不能出现以0开头的两位以上数字,比如012,08…
public class code93 {
public List<String> restoreIpAddresses(String s) {
List<String> ans = new ArrayList<>();
dfs(0,s,new ArrayList<String>(),ans);
return ans;
}
//pos当前扫到s的位置,curr暂存的ip片段
private void dfs(int pos, String s, ArrayList<String> curr, List<String> ans) {
if(curr.size()==4){
if(pos==s.length()){
//把curr中的ip片段拼接成ip字符串
ans.add(String.join(".", curr));
}
return ;
}
for(int i=1;i<=3;i++){
if (pos+i > s.length()) break;
String segment = s.substring(pos,pos+i);
// 剪枝条件:不能以0开头,不能大于255
if (segment.startsWith("0") && segment.length() > 1 || (i == 3 && Integer.parseInt(segment) > 255)) continue;
curr.add(segment);
dfs(pos+i,s,curr,ans);
//回溯
curr.remove(curr.size()-1);
}
}
}
39题
思路:
做到这里,已经慢慢掌握了回溯的框架。
题解图给的非常清晰了:
public class code39 {
List<List<Integer>> res = new ArrayList<>();
int target;
public List<List<Integer>> combinationSum(int[] candidates, int target) {
if(candidates==null)
return res;
Arrays.sort(candidates);
this.target = target;
dfs(0,new ArrayList<>(),candidates);
return res;
}
//每一个满足条件的tmp就是一个子集
private void dfs(int pos, List<Integer> tmp, int[] candidates) {
if(sum(tmp) == this.target){
//Java 中可变对象是引用传递,因此需要将当前 path 里的值拷贝出来
this.res.add(new ArrayList<>(tmp));
return;
}
for (int i=pos;i<candidates.length&&sum(tmp)<=target;i++){
tmp.add(candidates[i]);
dfs(i,tmp,candidates);
tmp.remove(tmp.size()-1);
}
return;
}
private int sum(List<Integer> tmp){
int sum = 0;
for (int a :tmp ) {
sum += a;
}
return sum;
}
}
40题
40题和39题非常像,不同之处在于:
- candidate中的数只能用一次
- candidate中的数有重复数
遇到的问题:例如:candidates = [10,1,2,7,6,1,5], target = 8。满足条件的包括[1,1,6],因为有两个1所以可以使用两次,但是当递归到第二个1时,就需要剪枝了。这里的判断条件让我纠结了很久。
因为数组是已经排序过的,这里i>pos表明此时是回溯过来的,若和前一个相同则需要剪枝。如果这里写i>0,那么[1,1,6]里的1也只能用一次。
public class code40 {
List<List<Integer>> res = new ArrayList<>();
int target;
public List<List<Integer>> combinationSum2(int[] candidates, int target) {
if(candidates==null)
return res;
Arrays.sort(candidates);
this.target = target;
dfs(0,new ArrayList<>(),candidates);
return res;
}
//每一个满足条件的tmp就是一个子集
private void dfs(int pos, List<Integer> tmp, int[] candidates) {
if(sum(tmp) == this.target){
//Java 中可变对象是引用传递,因此需要将当前 path 里的值拷贝出来
this.res.add(new ArrayList<>(tmp));
return;
}
for (int i=pos;i<candidates.length&&sum(tmp)<=target;i++){
if(i>pos){//剪枝:如candidate[1,1,2,5,6],当递归到第二个1时,进行剪枝
//这个判断条件一开始写成了i>0
if(candidates[i]==candidates[i-1]) continue;
}
tmp.add(candidates[i]);
dfs(i+1,tmp,candidates);
tmp.remove(tmp.size()-1);
}
return;
}
private int sum(List<Integer> tmp){
int sum = 0;
for (int a :tmp ) {
sum += a;
}
return sum;
}
}
46题
迎刃而解!
注意一下剪枝条件
public class code46 {
List<List<Integer>> res = new ArrayList<>();
public List<List<Integer>> permute(int[] nums) {
if(nums==null) return res;
Arrays.sort(nums);
dfs(nums,new ArrayList<>());
return res;
}
private void dfs(int[] nums,List<Integer> tmp) {
if(tmp.size()==nums.length)
{
res.add(new ArrayList<>(tmp));
return;
}
for (int i=0;i<nums.length;i++){
if(tmp.contains(nums[i])) continue;
tmp.add(nums[i]);
dfs(nums,tmp);
tmp.remove(tmp.size()-1);
}
}
}
题47
这题卡在剪枝这里了,参考了大佬的代码之后:
public class code47 {
List<List<Integer>> res = new ArrayList<>();
boolean[] used;
public List<List<Integer>> permuteUnique(int[] nums) {
if(nums==null) return res;
used = new boolean[nums.length];
Arrays.sort(nums);
dfs(nums,new ArrayList<>(),used);
return res;
}
private void dfs(int[] nums,List<Integer> tmp,boolean[] used) {
if(tmp.size()==nums.length)
{
res.add(new ArrayList<>(tmp));
return;
}
for (int i=0;i<nums.length;i++){
if(used[i]) continue;
//在 used[i - 1] 刚刚被撤销的时候剪枝,说明接下来会被选择,搜索一定会重复,故"剪枝"
if (i > 0 && nums[i - 1] == nums[i] && !used[i - 1]) {
continue;
}
used[i] = true;
tmp.add(nums[i]);
dfs(nums,tmp,used);
tmp.remove(tmp.size()-1);
used[i] = false;
}
}
}