After the shuffle write has put the map output on local disk, that data has to be read back from disk; this is what ShuffledRDD does:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read() // each reducer reads exactly one partition
.asInstanceOf[Iterator[(K, C)]]
}
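To see when the compute() above is actually hit, here is a minimal, hedged sketch of ordinary user code (the object name ShuffleReadDemo and the numbers are made up for illustration): any wide transformation such as reduceByKey inserts a ShuffledRDD, and each reduce task then calls compute() for exactly one of its partitions.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair RDD functions (needed explicitly on older Spark versions)

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))
    val pairs = sc.parallelize(1 to 100).map(i => (i % 4, 1))
    // reduceByKey builds a ShuffledRDD[Int, Int, Int]; collect() launches 4 reduce tasks,
    // each of which ends up in ShuffledRDD.compute() for its own partition
    val counts = pairs.reduceByKey(_ + _, 4)
    counts.collect().foreach(println)
    sc.stop()
  }
}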
As you can see, a reader is obtained through ShuffleManager.getReader. At this point Spark has only one kind of reader, HashShuffleReader. Here is the relevant source:
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
// We currently use the same block store shuffle fetcher as the hash-based shuffle.
new HashShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
Next, look at its read() method:
override def read(): Iterator[Product2[K, C]] = {
val ser = Serializer.getSerializer(dep.serializer)
<span style="color:#FF0000;">val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)</span> //真正的从file中抓取reducer所需的内容
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) { // the map side already combined by key, so merge combiners here
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
} else { // no map-side combine, so merge the raw values here
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
}
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else { // no aggregator; just convert to key-value pairs, the format downstream RDDs expect
// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) => // a key ordering is defined, so sort the output
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
sorter.iterator
case None =>
aggregatedIter // no ordering; return the aggregated iterator as is
}
}
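To make the two aggregation branches concrete, here is a hedged sketch in plain Scala (it can be pasted into a Scala REPL; no TaskContext is involved) of what combineCombinersByKey and combineValuesByKey have to do. It uses an average-per-key style aggregator where the combiner type C is a (sum, count) pair; all names and values are illustrative only.
// the three functions a Spark Aggregator[V = Int, C = (Int, Int)] would carry
val createCombiner: Int => (Int, Int) = v => (v, 1)
val mergeValue: ((Int, Int), Int) => (Int, Int) = (c, v) => (c._1 + v, c._2 + 1)
val mergeCombiners: ((Int, Int), (Int, Int)) => (Int, Int) = (a, b) => (a._1 + b._1, a._2 + b._2)

// mapSideCombine == true: the fetched iterator already holds (K, C) pairs, so the reader
// only folds with mergeCombiners (this is what combineCombinersByKey amounts to)
val fromMappers = Seq(("a", (3, 2)), ("a", (5, 3)), ("b", (4, 1)))
val merged = fromMappers.groupBy(_._1).mapValues(_.map(_._2).reduce(mergeCombiners))
println(merged) // e.g. Map(a -> (8,5), b -> (4,1))

// mapSideCombine == false: the iterator holds raw (K, V) pairs, so the reader must first
// createCombiner for a key's first value and then fold with mergeValue
// (this is what combineValuesByKey amounts to)
val rawValues = Seq(("a", 3), ("a", 5), ("b", 4))
val combined = rawValues.groupBy(_._1).mapValues { vs =>
  vs.tail.map(_._2).foldLeft(createCombiner(vs.head._2))(mergeValue)
}
println(combined) // e.g. Map(a -> (8,2), b -> (4,1))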
So the code that actually reads the shuffle files is BlockStoreShuffleFetcher.fetch(); its source is:
private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
<span style="color:#FF0000;">val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)</span> //得到一个(BlockManagerId, Long)的数组,就是块管理器的地址与该shuffleId对应的map输出文件的offset
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) // convert to (ShuffleBlockId, size) pairs
}
def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
}
val blockFetcherItr = new ShuffleBlockFetcherIterator( // iterator that fetches the block data
context,
SparkEnv.get.blockManager.shuffleClient, // the shuffle client used for block transfer
blockManager,
blocksByAddress,
serializer,
SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) // upper bound on in-flight remote block fetches, default 48MB
val itr = blockFetcherItr.flatMap(unpackBlock)
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})
new InterruptibleIterator[T](context, completionIter)
}
}
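To visualize the grouping that fetch() performs on the statuses, here is a hedged, self-contained sketch. The case classes Addr and ShuffleBlock are stand-ins for Spark's BlockManagerId and ShuffleBlockId, used only to show the shape of the transformation (the index into statuses is the mapId); the hosts, ports and sizes are made up.
import scala.collection.mutable.{ArrayBuffer, HashMap}

case class Addr(host: String, port: Int) // stand-in for BlockManagerId
case class ShuffleBlock(shuffleId: Int, mapId: Int, reduceId: Int) {
  def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId" // Spark's shuffle block naming scheme
}

object GroupStatusesDemo extends App {
  val shuffleId = 0
  val reduceId = 2
  // what getServerStatuses returns: one (address, size) entry per map task, indexed by mapId
  val statuses: Array[(Addr, Long)] =
    Array((Addr("host1", 7337), 1024L), (Addr("host2", 7337), 2048L), (Addr("host1", 7337), 512L))

  val splitsByAddress = new HashMap[Addr, ArrayBuffer[(Int, Long)]]
  for (((address, size), mapId) <- statuses.zipWithIndex) {
    splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((mapId, size))
  }
  val blocksByAddress = splitsByAddress.toSeq.map { case (address, splits) =>
    (address, splits.map { case (mapId, size) => (ShuffleBlock(shuffleId, mapId, reduceId), size) })
  }
  // roughly: host1 -> [shuffle_0_0_2 (1024 B), shuffle_0_2_2 (512 B)], host2 -> [shuffle_0_1_2 (2048 B)]
  blocksByAddress.foreach(println)
}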
The key step here is MapOutputTracker.getServerStatuses():
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) { // the map output statuses are not cached locally; fetch them from the driver
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized { // fetching the statuses is synchronized and may block
if (fetching.contains(shuffleId)) { // someone else is already fetching this shuffle's statuses; wait
// Someone else is fetching it; wait for them to be done
while (fetching.contains(shuffleId)) {
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull // check the local cache again; someone may have fetched the statuses while we were waiting
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId // add this shuffle to the set currently being fetched
}
}
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
<span style="color:#FF0000;">askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]</span> //通过akka获取到需要抓取的字节数组
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) // deserialize the statuses
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) // convert and return
}
} else {
logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
} else {
statuses.synchronized { // the statuses are cached locally
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}
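The synchronization above is a classic "fetch once, everyone else waits" pattern. Below is a hedged, generic sketch of that pattern on its own, mirroring the wait()/notifyAll() dance in getServerStatuses(); the names DedupFetch, cache and doRemoteFetch are invented for illustration and are not Spark code.
import scala.collection.mutable

object DedupFetch {
  private val cache = new mutable.HashMap[Int, Array[Byte]] // plays the role of mapStatuses
  private val fetching = new mutable.HashSet[Int]           // keys currently being fetched

  def get(key: Int)(doRemoteFetch: Int => Array[Byte]): Array[Byte] = {
    val cached = fetching.synchronized {
      while (fetching.contains(key)) fetching.wait() // someone else is fetching; wait for them
      val v = cache.get(key)                         // may have been filled while we waited
      if (v.isEmpty) fetching += key                 // we won the race; others now wait on us
      v
    }
    cached.getOrElse {
      try {
        val v = doRemoteFetch(key)                   // the expensive remote call, done outside the lock
        fetching.synchronized { cache.put(key, v) }
        v
      } finally {
        fetching.synchronized { fetching -= key; fetching.notifyAll() }
      }
    }
  }
}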
The interesting part is the Akka communication in the askTracker() method:
protected def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout) // send the message to trackerActor
Await.result(future, timeout) // block until the reply arrives
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
throw new SparkException("Error communicating with MapOutputTracker", e)
}
}
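For readers less familiar with Akka, here is a hedged, standalone sketch of the same ask/Await pattern in isolation. EchoActor, AskDemo and the system name are invented for illustration; this is plain classic Akka of the generation Spark used here, not Spark code itself.
import akka.actor.{Actor, ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout
import scala.concurrent.Await
import scala.concurrent.duration._

class EchoActor extends Actor {
  def receive = {
    case msg => sender ! s"echo: $msg" // reply to the asker, just like `sender ! mapOutputStatuses`
  }
}

object AskDemo extends App {
  val system = ActorSystem("demo")
  val actor = system.actorOf(Props[EchoActor], name = "echo")
  implicit val timeout = Timeout(30.seconds)
  val future = actor.ask("hello")                 // same pattern as trackerActor.ask(message)(timeout)
  val result = Await.result(future, timeout.duration)
  println(result)                                 // "echo: hello"
  system.shutdown()
}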
Here trackerActor is in fact an ActorRef of MapOutputTrackerMasterActor; since this code runs on a worker, it only holds a remote reference to that actor. The relevant code lives in SparkEnv:
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf) // map output tracker on the driver
} else {
new MapOutputTrackerWorker(conf) // map output tracker on an executor (the output data itself stays on the local worker node)
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name) // create the actual actor on the driver
} else {
AkkaUtils.makeDriverRef(name, conf, actorSystem) // on each worker, build a reference (ActorRef) to the driver-side actor
}
}
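What makeDriverRef amounts to, roughly, is looking up the driver-side actor by its remote path. Below is a hedged sketch of such a lookup; the URL format, the "sparkDriver" system name, and the helper's name and signature are assumptions about this Spark/Akka generation, not the actual AkkaUtils code.
import akka.actor.{ActorRef, ActorSystem}
import akka.util.Timeout
import scala.concurrent.Await
import scala.concurrent.duration._

object DriverActorLookup {
  def lookupDriverActor(actorSystem: ActorSystem, name: String,
                        driverHost: String, driverPort: Int): ActorRef = {
    implicit val timeout = Timeout(30.seconds)
    // remote path of the driver-side actor, e.g. .../user/MapOutputTracker
    val url = s"akka.tcp://sparkDriver@$driverHost:$driverPort/user/$name"
    Await.result(actorSystem.actorSelection(url).resolveOne(timeout.duration), timeout.duration)
  }
}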
Back to the Akka request: when MapOutputTrackerMasterActor receives the message, it handles it as follows:
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = sender.path.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) // the map output statuses, already serialized into a byte array
val serializedSize = mapOutputStatuses.size
if (serializedSize > maxAkkaFrameSize) {
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
/* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
* Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
* will ultimately remove this entire code path. */
val exception = new SparkException(msg)
logError(msg, exception)
throw exception
}
sender ! mapOutputStatuses // send the result back to the asker
Finally, back in BlockStoreShuffleFetcher.fetch(), look at the following code again:
val blockFetcherItr = new ShuffleBlockFetcherIterator( // iterator that fetches the block data
context,
SparkEnv.get.blockManager.shuffleClient, // the shuffle client used for block transfer
blockManager,
blocksByAddress,
serializer,
SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) // upper bound on in-flight remote block fetches, default 48MB
val itr = blockFetcherItr.flatMap(unpackBlock)
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})
new InterruptibleIterator[T](context, completionIter) // wrap and return
Now look at the shuffleClient, the block transfer client being used:
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) // fetch blocks from an external shuffle service running outside the executors
} else {
blockTransferService // fetch blocks directly from other executors (if an executor dies, the shuffle data it holds can no longer be read)
}
blockTransferService is chosen when SparkEnv is created; by default the Netty implementation is used:
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
new NettyBlockTransferService(conf, securityManager, numUsableCores)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
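To wrap up, here is a hedged example of the configuration knobs that this read path consults; the values shown are the defaults of this Spark generation, set explicitly only for illustration (this can be pasted into a spark-shell before creating a context).
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.blockTransferService", "netty") // "nio" selects NioBlockTransferService instead
  .set("spark.reducer.maxMbInFlight", "48")           // cap on in-flight remote block fetches, in MB
  .set("spark.shuffle.service.enabled", "false")      // true makes the reader use ExternalShuffleClient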
************ The End ************