Spark GraphX Study Notes: Dijkstra's Shortest-Path Algorithm

1. Dijkstra's shortest-path algorithm in Scala

import org.apache.spark.graphx._
def dijkstra[VD](g:Graph[VD,Double], origin:VertexId): Graph[(VD,Double), Double] = {
	
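	/**
	 * 1. Initialization
	 * Map every vertex of the graph to a pair (visited, distance):
	 * visited starts as false and distance as Double.MaxValue,
	 * except for the origin vertex, whose distance is 0.0.
	 */
	var g2 = g.mapVertices(
		(vid,vd) => (false, if (vid == origin) 0.0 else Double.MaxValue))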
	

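	/**
	 * 2. Repeat once per remaining vertex: among the unvisited vertices,
	 * pick the one with the smallest tentative distance as the current vertex.
	 */
	for (i <- 1L to g.vertices.count-1) {
		val currentVertexId =
			g2.vertices.filter(!_._2._1)
				.fold((0L,(false,Double.MaxValue)))((a,b) =>
					if (a._2._2 < b._2._2) a else b)
				._1

		// 3. Send messages to the vertices adjacent to the current vertex, then merge them,
		//    keeping the smallest value as the candidate shortest distance
		val newDistances: VertexRDD[Double] = g2.aggregateMessages[Double](
			// sendMsg: along each edge leaving the current vertex, send the current
			// vertex's distance plus the edge weight
			ctx => if (ctx.srcId == currentVertexId)
				ctx.sendToDst(ctx.srcAttr._2 + ctx.attr),
			// mergeMsg: when a neighbor receives several messages, keep the smaller value
			(a,b) => math.min(a,b))

		// 4. Build the updated graph: mark the current vertex as visited and keep
		//    the smaller of the old and the newly received distance
		g2 = g2.outerJoinVertices(newDistances)((vid, vd, newSum) =>
			(vd._1 || vid == currentVertexId,
			math.min(vd._2, newSum.getOrElse(Double.MaxValue))))
	}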
	g.outerJoinVertices(g2.vertices)((vid, vd, dist) =>
		(vd, dist.getOrElse((false,Double.MaxValue))._2))
}
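
Because the function initializes every distance to Double.MaxValue and only ever lowers it, a vertex that cannot be reached from the origin still carries Double.MaxValue in the result. The sketch below shows one possible way to turn that sentinel into an Option after collecting the results. The helper name reachableDistance and the sample values (including the vertex "Z") are hypothetical and used only for illustration.

```scala
// Hypothetical helper: treat Double.MaxValue as "no path found".
def reachableDistance(d: Double): Option[Double] =
  if (d == Double.MaxValue) None else Some(d)

// Hypothetical collected results: (vertex attribute, distance).
// "Z" stands for a vertex with no path from the origin.
val sample: Seq[(String, Double)] =
  Seq(("A", 0.0), ("G", 22.0), ("Z", Double.MaxValue))

// Replace the sentinel with None and wrap real distances in Some.
val cleaned: Seq[(String, Option[Double])] =
  sample.map { case (name, d) => name -> reachableDistance(d) }
```

Whether to keep the sentinel or convert it is a design choice; the Option form simply makes the unreachable case explicit to callers.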


2. Running the shortest-path algorithm

val myVertices = sc.makeRDD(Array(
	(1L, "A"), (2L, "B"), (3L, "C"), (4L, "D"),
	(5L, "E"), (6L, "F"), (7L, "G")))

val myEdges = sc.makeRDD(Array(
	Edge(1L, 2L, 7.0), Edge(1L, 4L, 5.0), Edge(2L, 3L, 8.0), Edge(2L, 4L, 9.0),
	Edge(2L, 5L, 7.0), Edge(3L, 5L, 5.0), Edge(4L, 5L, 15.0), Edge(4L, 6L, 6.0),
	Edge(5L, 6L, 8.0), Edge(5L, 7L, 9.0), Edge(6L, 7L, 11.0)))

val myGraph = Graph(myVertices, myEdges)
dijkstra(myGraph, 1L).vertices.map(_._2).collect

Output:
res0: Array[(String, Double)] = Array((D,5.0), (A,0.0), (F,11.0), (C,15.0), (G,22.0), (E,14.0), (B,7.0))
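
The distances above can be cross-checked without a Spark cluster. The following is a minimal sketch in plain Scala that runs the same select-then-relax loop on local collections, using the edge list from this section as directed edges. The name localDijkstra is hypothetical, and this is not the GraphX implementation, only a local reference to compare against.

```scala
// Directed edges (src, dst, weight), copied from the myEdges definition.
val edges: Seq[(Long, Long, Double)] = Seq(
  (1L, 2L, 7.0), (1L, 4L, 5.0), (2L, 3L, 8.0), (2L, 4L, 9.0),
  (2L, 5L, 7.0), (3L, 5L, 5.0), (4L, 5L, 15.0), (4L, 6L, 6.0),
  (5L, 6L, 8.0), (5L, 7L, 9.0), (6L, 7L, 11.0))

// Hypothetical local helper mirroring the structure of the GraphX version.
def localDijkstra(edges: Seq[(Long, Long, Double)], origin: Long): Map[Long, Double] = {
  val ids = edges.flatMap(e => Seq(e._1, e._2)).distinct
  // Tentative distances: 0.0 for the origin, Double.MaxValue elsewhere.
  var dist: Map[Long, Double] =
    ids.map(id => id -> (if (id == origin) 0.0 else Double.MaxValue)).toMap
  var visited = Set.empty[Long]
  // One iteration per vertex except the last, as in the GraphX loop.
  for (_ <- 1 until ids.size) {
    // Pick the unvisited vertex with the smallest tentative distance.
    val current = ids.filterNot(visited)
      .reduce((a, b) => if (dist(a) <= dist(b)) a else b)
    visited += current
    // Relax every directed edge leaving the current vertex.
    for ((src, dst, w) <- edges if src == current)
      dist = dist.updated(dst, math.min(dist(dst), dist(current) + w))
  }
  dist
}

val result = localDijkstra(edges, 1L)
```

Mapping the ids back to names (1 is A, 2 is B, and so on) should give the same pairs as the Spark output, although the Spark array may list them in a different order.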


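3. Dijkstra's shortest-path algorithm with path tracking
	This builds on section 1 and adds a List to each vertex that records the path found so far.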

import org.apache.spark.graphx._
def dijkstra[VD](g:Graph[VD,Double], origin:VertexId) = {
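	// Vertex state: (visited, tentative distance, vertices on the path from the origin so far)
	var g2 = g.mapVertices(
		(vid,vd) => (false, if (vid == origin) 0.0 else Double.MaxValue,List[VertexId]()))

	for (i <- 1L to g.vertices.count-1) {
		// Pick the unvisited vertex with the smallest tentative distance
		val currentVertexId =
			g2.vertices.filter(!_._2._1)
				.fold((0L,(false,Double.MaxValue,List[VertexId]())))((a,b) =>
				if (a._2._2 < b._2._2) a else b)._1

		// Send (distance, path) along each edge leaving the current vertex;
		// when merging, keep the message with the smaller distance
		val newDistances = g2.aggregateMessages[(Double,List[VertexId])](
			ctx => if (ctx.srcId == currentVertexId)
				ctx.sendToDst((ctx.srcAttr._2 + ctx.attr,ctx.srcAttr._3 :+ ctx.srcId)),
			(a,b) => if (a._1 < b._1) a else b)
		// Mark the current vertex as visited; keep the old path only when
		// the old distance is strictly smaller than the received one
		g2 = g2.outerJoinVertices(newDistances)((vid, vd, newSum) => {
			val newSumVal = newSum.getOrElse((Double.MaxValue,List[VertexId]()))
			(vd._1 || vid == currentVertexId,
			math.min(vd._2, newSumVal._1),
			if (vd._2 < newSumVal._1) vd._3 else newSumVal._2)})
	}
	// Drop the visited flag: each vertex ends up as (original attribute, List(distance, path))
	g.outerJoinVertices(g2.vertices)((vid, vd, dist) =>
		(vd, dist.getOrElse((false,Double.MaxValue,List[VertexId]())).productIterator.toList.tail))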
}
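
The update rule inside outerJoinVertices is the part that decides whether a vertex keeps its old path or adopts a new one. The sketch below pulls that closure out into a standalone plain-Scala function so it can be tried on single values. The names State and update are hypothetical, and the sample values assume the graph from section 2, where vertex 5 (E) is first reached through vertex 4 at distance 20.0 and later improved through vertex 2.

```scala
// (visited, tentative distance, path so far), as in the vertex attribute of g2.
type State = (Boolean, Double, List[Long])

// Hypothetical standalone copy of the closure passed to outerJoinVertices.
def update(vid: Long, current: Long, vd: State,
           msg: Option[(Double, List[Long])]): State = {
  val (newDist, newPath) = msg.getOrElse((Double.MaxValue, List.empty[Long]))
  (vd._1 || vid == current,              // visited once it has been the current vertex
   math.min(vd._2, newDist),             // keep the smaller distance
   if (vd._2 < newDist) vd._3 else newPath) // old path only if strictly shorter
}

// Vertex 2 is current and offers vertex 5 a shorter route (14.0 via 1 -> 2).
val improved = update(5L, 2L, (false, 20.0, List(1L, 4L)), Some((14.0, List(1L, 2L))))

// Vertex 6 is current and sends nothing to vertex 5: the state is unchanged.
val unchanged = update(5L, 6L, (false, 14.0, List(1L, 2L)), None)

// The current vertex itself is flagged as visited.
val markedVisited = update(2L, 2L, (false, 7.0, List(1L)), None)
```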

4. Running Dijkstra's shortest-path algorithm with path tracking

dijkstra(myGraph, 1L).vertices.map(_._2).collect

	Output:
	res1: Array[(String, List[Any])] = Array((D,List(5.0, List(1))), (A,List(0.0, List())), (F,List(11.0, List(1, 4))), (C,List(15.0, List(1, 2))), (G,List(22.0, List(1, 4, 6))), (E,List(14.0, List(1, 2))), (B,List(7.0, List(1))))

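	Reading the result: (G,List(22.0, List(1, 4, 6))) means that the shortest path from vertex 1L to G passes through vertices 1, 4, and 6 before arriving at G, for a total distance of 22.0. The destination vertex itself is not included in the list.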
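
Since the recorded list stops just before the destination, a readable route needs the destination appended and the ids mapped back to names. The following is a small plain-Scala sketch of that post-processing step. The helper name fullPath is hypothetical, and the id-to-name map is copied from the myVertices definition in section 2.

```scala
// Vertex names from section 2.
val names: Map[Long, String] = Map(
  1L -> "A", 2L -> "B", 3L -> "C", 4L -> "D",
  5L -> "E", 6L -> "F", 7L -> "G")

// Hypothetical helper: append the destination to the recorded path
// and render the ids as names.
def fullPath(dest: Long, recorded: List[Long]): String =
  (recorded :+ dest).map(names).mkString(" -> ")

// Entry for G from the output: recorded path List(1, 4, 6), destination 7.
val routeToG = fullPath(7L, List(1L, 4L, 6L))

// The origin has an empty recorded path, so its route is just itself.
val routeToA = fullPath(1L, Nil)
```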

Reference book: Spark GraphX 实战 (Spark GraphX in Action)

 
