理解回溯算法——回溯算法的初学者指南

0 前言

最近做了不少关于回溯法的算法题,积累了一些心得,这篇博文算是对回溯法的一个小总结。

1 回溯法简介

回溯法简单来说就是按照深度优先的顺序,穷举所有可能性的算法,但是回溯算法比暴力穷举法更高明的地方就是回溯算法可以随时判断当前状态是否符合问题的条件。一旦不符合条件,那么就退回到上一个状态,省去了继续往下探索的时间。
最基本的回溯法是在解空间中穷举所有的解。比如求序列[1,2,3]的全排列,那么我们可以画出一颗解空间树。
解空间树
回溯法的特点是深度优先遍历,也就是该问题的遍历顺序是1->2->3,然后从子节点3返回,从子节点2返回,再到1->3->2,以此类推。
状态的返回只有当前的节点不再满足问题的条件或者我们已经找到了问题的一个解时,才会返回,否则会以深度优先一直在解空间树内遍历下去。
当然,对于某些问题如果其解空间过大,即使用回溯法进行计算也有很高的时间复杂度,因为回溯法会尝试解空间树中所有的分支。所以根据这类问题,我们有一些优化剪枝策略以及启发式搜索策略。
所谓优化剪枝策略,就是判断当前的分支树是否符合问题的条件,如果当前分支树不符合条件,那么就不再遍历这个分支里的所有路径。
所谓启发式搜索策略指的是,给回溯法搜索子节点的顺序设定一个优先级,从该子节点往下遍历更有可能找到问题的解。

2 回溯函数的组成

1.回溯出口,当找到了一个问题的解时,存储该解。
2.回溯主体,就是遍历当前的状态的所有子节点,并判断下一个状态是否是满足问题条件的,如果满足问题条件,那么进入下一个状态。
3.状态返回,如果当前状态不满足条件,那么返回到前一个状态。

def backtrack(current_statement) -> bool:
	if condition is satisfy:
		solution = current_statement
		return True
	else:
		for diff_posibility in current_statement:
			next_statement = diff_posibility
			if next_statement is satisfy condition:
				if backtrack(next_statement):
					return True
				else:
					back to current_statement
		return False

3 简单的回溯函数

3.1 问题描述

给定一个不包含重复数字的序列,返回所有不重复的全排列。

3.1.1 问题分析

遍历所有的解空间树即可找到答案。
首先定义一个回溯函数

# combination 为当前的状态
backtrack(combination=[])

那么它的出口部分也很好写,就是当combination的长度等于序列的长度时,就找到了问题的一个解。

if len(combination) == len(nums):
       answer.append(combination)

然后是回溯函数的主体部分,我们要遍历当前状态下的所有子节点,并判断子节点是否还符合问题的条件,那么对于这个问题,因为全排列的数是不能够重复的,所以我们的判断方式是当前的数没有包含在combination中,那么进入下一个状态。

for num in nums:
    if num not in combination:
        backtrack(combination+[num])

那么这个问题需要返回上一个状态吗?答案是不需要,因为backtrack的下一个状态的写法是backtrack(combination + [num]),这并不会改变我们当前的combination的值,因为我们没有对combination对象进行一个重新的赋值操作。
如果说修改一下回溯函数的主体。

for num in nums:
    if num not in combination:
    	combination.append(num)
        backtrack(combination+[num])

那么这时候,combination的值被改变了,所以需要写一个返回上一个状态的代码。

for num in nums:
  if num not in combination:
      combination.append(num)
      backtrack(combination)
      combination.pop()

并且,因为我们传入的是相当于是combination对象,所以在存储解的时候需要深拷贝。

if combination.__len__() == nums.__len__():
    solution = copy.deepcopy(combination)
    answer.append(solution)

3.1.2 完整代码

import copy
class Solution:
    def permute(self, nums: list):
        answer = []
        def backtrack(combination=[]):
            if combination.__len__() == nums.__len__():
                solution = copy.deepcopy(combination)
                answer.append(solution)
                return
            for num in nums:
                if num not in combination:
                    combination.append(num)
                    backtrack(combination)
                    combination.pop()
        backtrack()
        return answer

3.2 问题描述

给定一个包含重复数字的序列,返回所有不重复的全排列。

3.2.1 问题分析

相对于第一个问题,这个问题稍微加了点难度,也就是序列中包含了重复的数字。由于有重复数字的关系,我们也就不能够只简单的判断一下某个数是否在combination中。我们可以构建一个hash表,来记录当前状态的hash键值。

hash_num = {
   
   }
for item in nums:
    hash_num[item] = hash_num.get(item,0) + 1

在回溯函数中,我们用hash表来判断是否可以将当前的数字加入到combination中。

def backtrack(combination:list=[],hash_num:dict=hash_num):
     if len(combination) == len(nums):
         output.append(combination)
     else:
         for num_key in list(hash_num.keys()):
             hash_num[num_key] = hash_num[num_key] - 1
             if hash_num[num_key] == 0:
                 hash_num.pop(num_key)
             backtrack(combination + [num_key],hash_num)
             hash_num[num_key] = hash_num.get(num_key,0) + 1

如果当前的数字在hash表中对应的值是1,那么进入到下一个状态之前,我们要删掉这个hash_key。
之后要注意把这个hash_table恢复回原来的状态。

3.2.2 完整代码

class Solution:
    def permuteUnique(self, nums: list) -> list:
        hash_num = {
   
   }
        for item in nums:
            hash_num[item] = hash_num.get(item,0) + 1
        output = []

        def backtrack(combination:list=[],hash_num:dict=hash_num):
            if len(combination) == len(nums):
                output.append(combination)
            else:
                for num_key in list(hash_num.keys()):
                    hash_num[num_key] = hash_num[num_key] - 1
                    if hash_num[num_key] == 0:
                        hash_num.pop(num_key)
                    backtrack(combination + [num_key],hash_num)
                    hash_num[num_key] = hash_num.get(num_key,0) + 1
        backtrack()
        return output

4 剪枝对于回溯函数的重要性

像是对于某些问题,如果要搜索全部的解空间的话,范围太大,如果能提前根据问题的特征排除某些不必要搜索的子空间,将大大的提高搜索效率。

4.1 问题描述

给定一个无重复元素的数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。
candidates 中的数字可以无限制重复被选取,并且所有数字(包括 target)都是正整数,解集不能包含重复的组合。
在某种意义上,如果对这个问题不进行剪枝,那这个问题的搜索空间是无限的。

4.1.1 问题分析

首先,因为candidates是可以重复选择的,所以在每个状态下,都有len(candidates)个子节点。
首先将candidates进行升序排序。
我们从根结点出发,选择一个子节点之后,加上将这个子节点的值,作为下一个状态,在遍历的过程中对遇到的值进行累加。
如果在某个状态下:
1.我们发现此时的状态代表的数值等于我们的target,那么它的右兄弟结点以及以它为根结点的子树不再进行探索。
2.我们发现此时的状态代表的数值要小于我们的target,那么继续进行探索。
3.我们发现此时的状态代表的数值要大于我们的target,那么它的右兄弟结点以及以它为根结点的子树不再进行探索。
剪枝的过程
如上图所示,我们对遇到的结点进行累加,如果发现有一个结点的值是我们target的值,因为candidates的值是按照升序排序的,并且candidates的数值不可能重复,那么它的右兄弟结点的状态值只能够大于target。candidates中的值只包含正数,因此,以它为根结点的子树下的所有结点的状态值都会大于target,因此这些结点我们都没有必要进行探索了。
同理,如果我们发现当前结点的状态值要大于我们的target值,那么其右兄弟结点的状态值,以及以它为根节点的子树下的所有结点的状态值都会大于target,因此这些结点也是没有必要探索的。
根据这些限制条件,我们可以大大的缩小我们搜索的子空间,提高问题解答的效率。

4.1.2 完整代码

class Solution:
    def combinationSum(self, candidates: list, target: int) -> list:
        candidates = sorted(candidates)
        answer = []
        def backtrack(current_sum:int=0,current_list:list=[]):
            if current_sum == target:
                if sorted(current_list) not in answer:
                    answer.append(current_list)
            else:
                for number in candidates:
                    if current_sum + number > target:
                        break
                    else:
                        backtrack(current_sum+number,current_list+[number])
        backtrack()
        return answer

4.2 问题描述

给出集合 [1,2,3,…,n],其所有元素共有 n! 种排列。
按大小顺序列出所有排列情况,并一一标记,当 n = 3 时, 所有排列如下:
“123”
“132”
“213”
“231”
“312”
“321”
给定 n 和 k,返回第 k 个排列,其中给定 n 的范围是 [1, 9],给定 k 的范围是[1, n!]。

4.2 问题分析

这个问题是能够更加的突显剪枝的重要性,如果不对问题进行剪枝,我们也可以很容易的对问题进行求解:对找到的解进行计数,当找到第k个解时,停止回溯算法,返回结果即可。

代码1-无剪枝算法
import time
class Solution:
    def getPermutation(self, n: int, k: int) -> str:
        number_list = []
        for i in range(n):
            number_list.append(str(i+1))
        answer = []
        count = [0]
        def backtrack(combination,k):
            if count[0] == k:
                return None
            if len(combination) == len(number_list):
                temp = count[0]
                count[0] = temp + 1
                if count[0] == k:
                    answer.append(combination)
                return None
            else:
                for number in number_list:
                    if number not in combination:
                        backtrack(combination + number,k)

        backtrack("",k)
        return answer[0]

if __name__ == '__main__':
    start_time = time.time()
    print(Solution().getPermutation(9,362880))
    end_time = ti
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值