引言
在贪心算法一章中,我们看到对于许多问题,贪心策略未能提供最优解。在这些问题中,有一些可以通过使用分治(Divide and Conquer, D&C)技术轻松解决。分治是一种基于递归的重要算法设计技术。
分治算法通过递归地将问题分解为两个或多个相同类型的子问题,直到这些子问题简单到可以直接解决。然后,将子问题的解合并,以给出原始问题的解。
什么是分治策略?
分治策略通过以下步骤解决问题:
-
分解:将问题分解为本身是同一类型问题的较小实例的子问题。
-
递归:递归地解决这些子问题。
-
合并:适当地组合它们的答案。
分治算法是否总是有效?
并非所有问题都可以用分治技术解决。根据分治的定义,递归解决的是相同类型的子问题。对于所有问题,并非总是能找到相同大小的子问题,因此分治并非适用于所有问题。
分治可视化
为了更好地理解,考虑以下可视化。假设 n 是原始问题的大小。如上所述,我们可以看到问题被分解为每个大小为 n/b(对于某个常数 b)的子问题。我们递归地解决子问题,并将它们的解合并,以得到原始问题的解。
理解分治算法
为了更清晰地理解分治算法,让我们考虑一个故事。有一位年老的富农,他有七个儿子。他担心自己去世后,他的土地和财产会被七个儿子瓜分,他们之间会争吵不休。
于是他召集他们,给他们看了七根绑在一起的棍子,并告诉他们,谁能折断这捆棍子,谁就能继承一切。他们都试了,但没有人能折断这捆棍子。然后,老人解开这捆棍子,一根一根地折断了。兄弟们决定他们应该团结一致,共同努力,共同成功。对于问题解决者来说,寓意不同。如果我们不能解决一个问题,就将其分解为部分,一次解决一个部分。
在前面的章节中,我们已经根据分治策略解决了许多问题:如二分查找、归并排序、快速排序等……参考这些主题,以了解分治是如何工作的。以下是其他一些可以用分治策略轻松解决的实时问题。对于所有这些问题,我们都可以找到与原始问题相似的子问题。
-
在电话簿中查找名字:我们有一本按字母顺序排列名字的电话簿。给定一个名字,我们如何确定这个名字是否在电话簿中?
-
将石头破碎成粉末:我们想把一块石头变成粉末(非常小的石头)。
-
在酒店中找到出口:我们位于一个非常长的酒店大厅的尽头,有一系列长长的门,其中一扇门就在我们旁边。我们正在寻找通往出口的门。
-
在停车场找到我们的车。
分治算法的优点
解决难题:分治是一种解决难题的强大方法。例如,考虑汉诺塔问题。这需要将问题分解为子问题,解决平凡情况,并将子问题组合起来解决原始问题。将问题分解为可以再次组合的子问题是设计新算法的主要难点。对于许多这样的问题,分治提供了一个简单的解决方案。
并行性:由于分治允许我们独立解决子问题,这使得在多处理器机器上执行成为可能,特别是在共享内存系统中,处理器之间的数据通信不需要提前规划,因为不同的子问题可以在不同的处理器上执行。
内存访问:分治算法自然倾向于高效利用内存缓存。这是因为一旦子问题足够小,其所有子问题都可以在缓存中解决,无需访问较慢的主内存。
分治算法的缺点
分治方法的一个缺点是递归速度慢。这是因为重复子问题调用的开销。此外,分治方法需要栈来存储调用(递归中每个点的状态)。实际上,这取决于实现风格。对于许多问题,只要有足够大的递归基本情况,递归的开销就可以变得微不足道。
分治的另一个问题是,对于某些问题,它可能比迭代方法更复杂。例如,要添加 n 个数字,一个简单的循环按顺序将它们相加,比分治方法更容易,后者将数字集分成两半,递归地将它们相加,然后将总和相加。
主定理
如上所述,在分治方法中,我们递归地解决子问题。所有问题通常都用递归定义来定义。这些递归问题可以很容易地用主定理解决。有关主定理的详细信息,请参阅算法分析导论章节。为了连贯性,让我们重新考虑主定理。
如果递归的形式为
其中:
-
T(n)
是要解决的问题的大小为n
时的时间复杂度。 -
a
是递归中产生的子问题的数量。 -
n/b
是每个子问题的大小,其中b > 1
。 -
f(n)
是将问题分解和合并子问题解所需的时间。
主定理根据 与
的比较,将递归关系分为三种情况:
情况 1: 

如果存在一个常数 ε > 0
,使得 f(n)
的增长速度慢于
,那么:
情况 2: 
如果 f(n)
的增长速度与
相同,那么:
情况 3: 
如果存在一个常数 ε > 0
,使得 f(n)
的增长速度快于 ,并且满足正则性条件(即
a * f(n/b) ≤ c * f(n)
对于某个常数 c < 1
和所有足够大的 n
),那么:
例子
考虑递归关系
。
-
这里,
a = 2
,b = 2
,所以 -
,所以
根据情况 2,我们有:
因此,递归算法主定理提供了一种快速确定许多分治算法时间复杂度的方法。
分治算法的应用
1)二分查找
问题描述
给定一个排序数组和一个目标值,编写一个算法来查找目标值在数组中的索引。如果目标值不存在于数组中,返回 -1。
解决方案
二分查找是一种高效的查找算法,时间复杂度为 O(logn)。它通过不断将数组分成两半,逐步缩小搜索范围,直到找到目标值或搜索范围为空。
示例代码
def binary_search(arr, target):
low = 0
high = len(arr) - 1
while low <= high:
mid = (low + high) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
low = mid + 1
else:
high = mid - 1
return -1
# 示例输入输出
arr = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
target = 9
result = binary_search(arr, target)
print(f"目标值{target}的索引为:{result}")
2)归并排序和快速排序
问题描述
归并排序和快速排序是两种常用的高效排序算法。归并排序的时间复杂度为 O(nlogn),快速排序的平均时间复杂度为 O(nlogn),但在最坏情况下可能退化为 O(n^2)。
解决方案
归并排序通过递归地将数组分成两半,分别排序,然后合并两个已排序的子数组。快速排序通过选择一个基准值,将数组分成两部分,一部分小于基准值,另一部分大于基准值,然后递归地排序这两部分。
示例代码
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
def merge(left, right):
result = []
i, j = 0, 0
while i < len(left) and j < len(right):
if left[i] < right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
# 示例输入输出
arr = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]
print(f"Original array: {arr}")
print(f"Merge sorted array: {merge_sort(arr)}")
print(f"Quick sorted array: {quick_sort(arr)}")
3)中位数查找
问题描述
给定两个排序数组,编写一个算法来查找这两个数组并集中的中位数。
解决方案
使用分治法,时间复杂度为 O(logn)。通过不断缩小搜索范围,找到两个数组的中位数。
示例代码
def find_median_sorted_arrays(nums1, nums2):
m, n = len(nums1), len(nums2)
if m > n:
nums1, nums2, m, n = nums2, nums1, n, m
imin, imax, half_len = 0, m, (m + n + 1) // 2
while imin <= imax:
i = (imin + imax) // 2
j = half_len - i
if i < m and nums2[j-1] > nums1[i]:
imin = i + 1
elif i > 0 and nums1[i-1] > nums2[j]:
imax = i - 1
else:
if i == 0: max_of_left = nums2[j-1]
elif j == 0: max_of_left = nums1[i-1]
else: max_of_left = max(nums1[i-1], nums2[j-1])
if (m + n) % 2 == 1:
return max_of_left
if i == m: min_of_right = nums2[j]
elif j == n: min_of_right = nums1[i]
else: min_of_right = min(nums1[i], nums2[j])
return (max_of_left + min_of_right) / 2.0
# 示例输入输出
nums1 = [1, 3]
nums2 = [2]
print(f"Median: {find_median_sorted_arrays(nums1, nums2)}")
4)最小值和最大值查找
问题描述
给定一个数组,编写一个算法来查找数组中的最小值和最大值。
解决方案
通过一次遍历数组,同时记录最小值和最大值,时间复杂度为 O(n)。
示例代码
def find_min_max(arr):
if not arr:
return None, None
min_val = max_val = arr[0]
for num in arr:
if num < min_val:
min_val = num
if num > max_val:
max_val = num
return min_val, max_val
# 示例输入输出
arr = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]
min_val, max_val = find_min_max(arr)
print(f"Array: {arr}, Min: {min_val}, Max: {max_val}")
5)矩阵乘法
问题描述
给定两个 n×n 矩阵 A 和 B,计算 n×n 矩阵 C = A × B。
解决方案
使用 Strassen 算法,时间复杂度为 O(nlog27)。通过将矩阵分成四个子矩阵,递归地计算 7 个矩阵乘法,然后合并结果。
示例代码
def strassen_matrix_multiply(A, B):
n = len(A)
if n == 1:
return [[A[0][0] * B[0][0]]]
mid = n // 2
A11 = [row[:mid] for row in A[:mid]]
A12 = [row[mid:] for row in A[:mid]]
A21 = [row[:mid] for row in A[mid:]]
A22 = [row[mid:] for row in A[mid:]]
B11 = [row[:mid] for row in B[:mid]]
B12 = [row[mid:] for row in B[:mid]]
B21 = [row[:mid] for row in B[mid:]]
B22 = [row[mid:] for row in B[mid:]]
M1 = strassen_matrix_multiply(add_matrices(A11, A22), add_matrices(B11, B22))
M2 = strassen_matrix_multiply(add_matrices(A21, A22), B11)
M3 = strassen_matrix_multiply(A11, subtract_matrices(B12, B22))
M4 = strassen_matrix_multiply(A22, subtract_matrices(B21, B11))
M5 = strassen_matrix_multiply(add_matrices(A11, A12), B22)
M6 = strassen_matrix_multiply(subtract_matrices(A21, A11), add_matrices(B11, B12))
M7 = strassen_matrix_multiply(subtract_matrices(A12, A22), add_matrices(B21, B22))
C11 = add_matrices(subtract_matrices(add_matrices(M1, M4), M5), M7)
C12 = add_matrices(M3, M5)
C21 = add_matrices(M2, M4)
C22 = subtract_matrices(add_matrices(add_matrices(M1, M3), M6), M2)
C = [[0] * n for _ in range(n)]
for i in range(mid):
for j in range(mid):
C[i][j] = C11[i][j]
C[i][j + mid] = C12[i][j]
C[i + mid][j] = C21[i][j]
C[i + mid][j + mid] = C22[i][j]
return C
def add_matrices(A, B):
n = len(A)
return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]
def subtract_matrices(A, B):
n = len(A)
return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]
# 示例输入输出
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = strassen_matrix_multiply(A, B)
print(f"Matrix A: {A}")
print(f"Matrix B: {B}")
print(f"Matrix C = A × B: {C}")
6)最近点对问题
问题描述
给定一个平面上的点集,找出其中距离最近的两个点。
解决方案
使用分治法,时间复杂度为 O(nlogn)。将点集按 x 坐标排序,然后递归地将点集分成两半,分别在每半中寻找最近点对,最后合并结果。
示例代码
import math
def distance(p1, p2):
return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
def brute_force(points):
min_dist = float('inf')
for i in range(len(points)):
for j in range(i + 1, len(points)):
min_dist = min(min_dist, distance(points[i], points[j]))
return min_dist
def strip_closest(strip, d):
min_dist = d
strip.sort(key=lambda point: point[1])
for i in range(len(strip)):
for j in range(i + 1, len(strip)):
if (strip[j][1] - strip[i][1]) >= min_dist:
break
min_dist = min(min_dist, distance(strip[i], strip[j]))
return min_dist
def closest_util(points):
if len(points) <= 3:
return brute_force(points)
mid = len(points) // 2
mid_point = points[mid]
dl = closest_util(points[:mid])
dr = closest_util(points[mid:])
d = min(dl, dr)
strip = []
for point in points:
if abs(point[0] - mid_point[0]) < d:
strip.append(point)
return min(d, strip_closest(strip, d))
def closest(points):
points.sort(key=lambda point: point[0])
return closest_util(points)
# 示例输入输出
points = [(2, 3), (12, 30), (40, 50), (5, 1), (3, 4), (6, 8)]
min_distance = closest(points)
print(f"Points: {points}")
print(f"Minimum distance between any two points: {min_distance}")
扩展阅读:更多问题与解决方案
问题 1
考虑一个算法 A,它通过将问题分成五个大小为一半的子问题,递归地解决每个子问题,然后在线性时间内合并解决方案。该算法的复杂度是多少?
解决方案
假设输入大小为 n,T(n) 定义为问题的解决方案。根据描述,算法将问题分成 5 个子问题,每个子问题的大小为 n/2。因此,我们需要解决 5 个子问题。解决这些子问题后,算法在 O(n) 时间内扫描给定数组以合并解决方案。该问题的总递归算法可以表示为: T(n)=5T(n/2)+O(n) 使用主定理(Master theorem),我们得到复杂度为 。
示例代码
def divide_and_conquer(n):
if n <= 1:
return 1
else:
# 递归解决 5 个子问题
subproblems = [divide_and_conquer(n // 2) for _ in range(5)]
# 合并解决方案
result = sum(subproblems) + n
return result
# 示例输入输出
n = 16
print(f"Input size: {n}, Result: {divide_and_conquer(n)}")
问题 2
类似问题 1,算法 B 通过递归解决两个大小为 n−1 的子问题,然后在常数时间内合并解决方案。该算法的复杂度是多少?
解决方案
假设输入大小为 n,T(n) 定义为问题的解决方案。根据描述,算法将问题分成 2 个子问题,每个子问题的大小为 n−1。因此,我们需要解决 2 个子问题。解决这些子问题后,算法在常数时间内合并解决方案。该问题的总递归算法可以表示为: T(n)=2T(n−1)+O(1) 使用主定理(Subtract and Conquer),我们得到复杂度为 。
示例代码
def divide_and_conquer(n):
if n <= 1:
return 1
else:
# 递归解决 2 个子问题
subproblems = [divide_and_conquer(n - 1) for _ in range(2)]
# 合并解决方案
result = sum(subproblems) + 1
return result
# 示例输入输出
n = 5
print(f"Input size: {n}, Result: {divide_and_conquer(n)}")
问题 3
类似问题 1,算法 C 通过将问题分成九个大小为 3n 的子问题,递归地解决每个子问题,然后在 O(n2) 时间内合并解决方案。该算法的复杂度是多少?
解决方案
假设输入大小为 n,T(n) 定义为问题的解决方案。根据描述,算法将问题分成 9 个子问题,每个子问题的大小为 n/3。因此,我们需要解决 9 个子问题。解决这些子问题后,算法在 O(n^2) 时间内合并解决方案。该问题的总递归算法可以表示为: T(n)=9T(n/3)+O(n^2) 使用主定理(Master theorem),我们得到复杂度为 。
示例代码
def divide_and_conquer(n):
if n <= 1:
return 1
else:
# 递归解决 9 个子问题
subproblems = [divide_and_conquer(n // 3) for _ in range(9)]
# 合并解决方案
result = sum(subproblems) + n * n
return result
# 示例输入输出
n = 27
print(f"Input size: {n}, Result: {divide_and_conquer(n)}")
问题 4
编写一个递归算法并求解。
解决方案
假设输入大小为 n,T(n) 定义为问题的解决方案。根据给定代码,算法在打印字符后将问题分成 2 个子问题,每个子问题的大小为 n/2。因此,我们需要解决 2 个子问题。解决这些子问题后,算法不进行任何合并操作。该问题的总递归算法可以表示为: T(n)=2T(n/2)+O(1) 使用主定理(Master theorem),我们得到复杂度为 。
示例代码
def divide_and_conquer(n):
if n <= 1:
print("Character")
return 1
else:
# 递归解决 2 个子问题
subproblems = [divide_and_conquer(n // 2) for _ in range(2)]
return sum(subproblems)
# 示例输入输出
n = 8
print(f"Input size: {n}, Result: {divide_and_conquer(n)}")
问题 5
给定一个数组,编写一个算法来查找最大值和最小值。
解决方案
参考选择算法章节。
问题 6:二分查找
讨论二分查找及其复杂度。
解决方案
参考搜索章节。
分析
假设输入大小为 n,T(n) 定义为问题的解决方案。元素是有序的。在二分查找中,我们取中间元素并检查要搜索的元素是否等于该元素。如果相等,则返回该元素。如果要搜索的元素大于中间元素,则我们考虑右子数组并丢弃左子数组。反之亦然。这意味着在两种情况下,我们都丢弃了一半的子数组,只考虑剩余的一半。此外,在每次迭代中,我们将元素分成两个相等的部分。根据上述讨论,每次我们将问题分成 2 个子问题,每个子问题的大小为 n/2,并解决其中一个子问题。该问题的总递归算法可以表示为: T(n)=T(n/2)+O(1) 使用主定理(Master theorem),我们得到复杂度为 O(logn)。
示例代码
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
# 示例输入输出
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9]
target = 5
print(f"Array: {arr}, Target: {target}, Result: {binary_search(arr, target)}")
问题 7:三元查找
考虑二分查找的修改版本。假设数组被分成 3 个相等的部分(三元查找)而不是 2 个相等的部分。编写该三元查找的递归关系并求解其复杂度。
解决方案
从问题 5 的讨论中,二分查找的递归关系为: T(n)=T(n/2)+O(1) 类似地,我们将 2 替换为 3,表示我们将数组分成 3 个相等大小的子数组,并只考虑其中一个。因此,三元查找的递归关系可以表示为: T(n)=T(n/3)+O(1) 使用主定理(Master theorem),我们得到复杂度为 O(logn)(我们不需要担心对数的底数,因为它们是常数)。
示例代码
def ternary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid1 = left + (right - left) // 3
mid2 = right - (right - left) // 3
if arr[mid1] == target:
return mid1
if arr[mid2] == target:
return mid2
if target < arr[mid1]:
right = mid1 - 1
elif target > arr[mid2]:
left = mid2 + 1
else:
left = mid1 + 1
right = mid2 - 1
return -1
# 示例输入输出
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9]
target = 5
print(f"Array: {arr}, Target: {target}, Result: {ternary_search(arr, target)}")
问题 8:不等分查找
在问题 5 中,如果我们把数组分成两个大小大约为三分之一和三分之二的部分,会怎样?
解决方案
我们考虑三元查找的一个稍微修改的版本,其中只进行一次比较,创建两个分区,一个大约有 n/3 个元素,另一个有 2n/3 个元素。在这种情况下,最坏情况是递归调用在较大的元素部分。因此,最坏情况的递归关系为: T(n)=T(2n/3)+O(1) 使用主定理(Master theorem),我们得到复杂度为 O(logn)。值得注意的是,对于一般的 k-ary 查找(只要 k 是一个不依赖于 n 的固定常数),当 n 趋向于无穷大时,我们也会得到相同的结果。
示例代码
def modified_ternary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = left + (right - left) // 3
if arr[mid] == target:
return mid
if target < arr[mid]:
right = mid - 1
else:
left = mid + 1
return -1
# 示例输入输出
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9]
target = 5
print(f"Array: {arr}, Target: {target}, Result: {modified_ternary_search(arr, target)}")
问题 9
讨论归并排序及其复杂度。
解决方案
参考排序章节。
问题 10
讨论快速排序及其复杂度。
解决方案
参考排序章节。
总结
分治法是一种通过递归地将问题分解成较小的子问题,解决这些子问题,然后合并它们的解来解决原问题的算法策略。