算法小课堂——最小生成树Kruskal
前言
hello,大家好吖,上周我们讲解了 Floyd算法,不知大家还有没有印象嘞~,今天我们来讲解一下最小生成树,简单又实用的策略优化算法
算法原理
最小生成树(Kruskal 算法)适用于求解连通所有节点所需花费最小的问题。最形象的例子就是: 有n个村子,每个村子与其他村子修筑连通的道路需要花费一定金额,怎么样在保证所有村子连通的情况下使花费最少呢?
最小生成树本质是贪心,从耗费最少的边开始,不断加入集合,直到所有节点都连通。
首先,我们将所有的边按照权重(消耗值)进行从小到大排序,这里我们需要一个集合来存储最后选出的满足条件的边。然后遍历排序后的边,如果起点终点不在一个集合,则将终点加入起点的集合中,如果已经在一个集合里,说明前面用更少的消耗(之前从小到大排列的作用)达到了相同的效果,则不采用这条边。
可以发现,这个算法过程主要是连通分量的查询和合并。需要知道任意两个节点是不是在同一个连通分量中,根据需要还要合并两个连通分量。这里说的连通分量可以认为是边的集合,不唯一。
存储连通分量我们可以想到链表或者邻接矩阵,查询连通我们可以使用DFS或者BFS,但是是对于图结构而言,而且这种思路会比较复杂点,查询上效率不高。 坦白说就是费劲
这里一般教科书推荐是并查集(Union-Find Set),可以把每个连通分量看成一个集合,该集合包含了在这个连通分量的所有节点。而我们不需要去考虑连通的方式,比如顺序或者连通的形式,只需要考虑它是不是我的崽。查询的过程简单了,合并的过程利用集合的基本操作就可以完成,整个过程就可以比较简洁。
下面我们用例子来演示一下:
比如我们通过查询边已经得到三个不同的连通分量如下:
然后如果橙子集合中的任意一个节点有连接蓝色集合里的任意一个节点的边(查询是不是同一个连通分量),且为当前下一步的最优,则接下来的连通分量会进行合并:
算法的终止条件可以是遍历完所有的边,也可以是当连通分量有n-1个时(n个节点的图其无环连通边的个数为n-1)。既然要耗费最少,那自然是没有多余的边。
代码实现
首先我们需要实现并查集的类及相应的方法。
class Union_find_set:
"""
the implementation of union_find_set
"""
def __init__(self, identity, subset):
# identity: 标号,即所属连通分量的标签
# subset: 当前集合的数据
self.identity = identity
self.set = set()
if isinstance(subset, set):
self.set = subset
elif isinstance(subset, list) or isinstance(subset, tuple):
for item in subset:
self.set.add(item)
else:
raise ValueError("Subset should be list or set or tuple type")
# 获取当前连通分量的标签
def get_id(self):
return self.identity
# 设置当前连通分量的标签
def set_id(self, value):
self.identity = value
return
# 获取当前连通分量的节点数据
def get_set(self):
return (i for i in self.set)
# 设置当前连通分量的节点数据
def set_set(self, s):
self.set = s
return
# 返回当前连通分量与其他连通分量的并集
def union(self, other):
return self.set.union(other.set)
然后我们来看看Kruskal算法的实现(以开头作为解释的问题为例:有n个村子m条待修的路,每条路连通两个村子,且耗费一定金额。求使所有村子互通的最少金额)
def Mini_spanning_tree():
"""
implement Mini_spanning_tree using Union-Find set
"""
# 数据读入
n, k = list(map(int, input().split()))
data = []
for i in range(k):
data.append(list(map(int, input().split())))
data = sorted(data, key=lambda x: x[2])
# 总耗费
tot = 0
# 选出的边集合
path = []
s = dict()
# 初始化: 节点各自为营
for i in range(n+1):
s[i] = Union_find_set(i, [i])
# 遍历
for d in data:
u, v, w = d
# 编号不同且路径小于n-1,连通情况下路径数量等于节点数-1
if s[u].get_id() != s[v].get_id() and len(path) < n - 1:
# 修改节点v所在连通分量里所有节点的标签号
for i in s[v].get_set():
s[i].set_id(s[u].get_id())
# 获取节点u和节点v各自所在的连通分量的并集
temp = s