spark1.2.0源码分析之ShuffledRDD抓取数据

本文深入探讨了Spark Shuffle过程中如何从本地磁盘读取数据的详细步骤,包括ShuffleRDD的实现原理,从ShuffleManager获取读取器,以及如何通过HashShuffleReader和BlockStoreShuffleFetcher从文件中抓取数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

当经过shuffle写数据到本地磁盘后,需要从磁盘中将数据读取出来,这个是 ShuffledRDD 做的事情:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()     // 每次reducer都读一个partition
      .asInstanceOf[Iterator[(K, C)]]
  }

可以看出,是通过 ShuffleManager.getReader方法来获得一个读取器,目前spark只有一种类型的读取器:HashShuffleReader,看一下具体源码:

  override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = {
    // We currently use the same block store shuffle fetcher as the hash-based shuffle.
    new HashShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }
继续查看其 read() 方法:

 override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    <span style="color:#FF0000;">val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)</span>  //真正的从file中抓取reducer所需的内容

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {  //mapper端已经按key进行聚合了,此时,合并combiners
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {   //mapper端没有进行聚合,此时,合并values
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
      }
    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
      throw new IllegalStateException("Aggregator is empty for map-side combine")
    } else {  //没有聚合器,将其转换为键值对,因为之后的rdd需要这样的格式
      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>   //需要排序的情况
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
        context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
        sorter.iterator
      case None =>
        aggregatedIter  //最后返回一个迭代器
    }
  }

可以看出,真正读取磁盘文件的代码是:BlockStoreShuffleFetcher.fetch(),其源码如下:

private[hash] object BlockStoreShuffleFetcher extends Logging {
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
  {
    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager

    val startTime = System.currentTimeMillis
    <span style="color:#FF0000;">val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)</span>  //得到一个(BlockManagerId, Long)的数组,就是块管理器的地址与该shuffleId对应的map输出文件的offset
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))

    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))   //转换成相应的格式
    }

    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
        case Success(block) => {
          block.asInstanceOf[Iterator[T]]
        }
        case Failure(e) => {
          blockId match {
            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
            case _ =>
              throw new SparkException(
                "Failed to get block " + blockId + ", which is not a shuffle block", e)
          }
        }
      }
    }

    val blockFetcherItr = new ShuffleBlockFetcherIterator(   //抓取块数据的迭代器
      context,
      SparkEnv.get.blockManager.shuffleClient,  //所使用的块传输服务器
      blockManager,
      blocksByAddress,
      serializer,
      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)  //远程块抓取的最大值,默认48M
    val itr = blockFetcherItr.flatMap(unpackBlock)

    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      context.taskMetrics.updateShuffleReadMetrics()
    })

    new InterruptibleIterator[T](context, completionIter)
  }
}
查看关键的一步代码:

  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {  //map输出文件不在本地,需要从远程抓取
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {  //抓取文件是阻塞的
        if (fetching.contains(shuffleId)) {  //有人在抓取该文件,需要等待
          // Someone else is fetching it; wait for them to be done
          while (fetching.contains(shuffleId)) {
            try {
              fetching.wait()
            } catch {
              case e: InterruptedException =>
            }
          }
        }

        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull  //继续看空本地是否有,有可能在等待的过程中,有人已经抓取到本地了
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId  //加入正在抓取的集合中
        }
      }

      if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes =
            <span style="color:#FF0000;">askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]</span>  //通过akka获取到需要抓取的字节数组
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)  //反序列化
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      if (fetchedStatuses != null) {
        fetchedStatuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)  //转换之后,返回
        }
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses.synchronized {   //数据在本地
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
      }
    }
  }

重点在于akka的通信,askTracker()方法:

  protected def askTracker(message: Any): Any = {
    try {
      val future = trackerActor.ask(message)(timeout)  //向trackActor发送消息
      Await.result(future, timeout)    //返回结果
    } catch {
      case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }
这里的trackActor实际上就是MapOutputTrackerMasterActor 的ActorRef,因为是worker,所以是他的一个引用,具体代码在SparkEnv中,如下:

   val mapOutputTracker =  if (isDriver) {
      new MapOutputTrackerMaster(conf)  //mapper阶段的输出跟踪器,主机
    } else {
      new MapOutputTrackerWorker(conf)  //mapper阶段的输出跟踪器,从机(输出的数据保存在本地work节点)
    }
    mapOutputTracker.trackerActor = registerOrLookup(
      "MapOutputTracker",
      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
      if (isDriver) {
        logInfo("Registering " + name)
        actorSystem.actorOf(Props(newActor), name = name)  //在drive上创建相应的Actor
      } else {
        AkkaUtils.makeDriverRef(name, conf, actorSystem)  //在每个worker节点上创建Actor引用(ActorRef)
      }
    }

继续回到原来的akka请求那里,MapOutputTrackerMasterActor收到消息后,进行处理,如下:

    case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = sender.path.address.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)  //得到的是序列化后的字节数组数据
      val serializedSize = mapOutputStatuses.size
      if (serializedSize > maxAkkaFrameSize) {
        val msg = s"Map output statuses were $serializedSize bytes which " +
          s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."

        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
         * will ultimately remove this entire code path. */
        val exception = new SparkException(msg)
        logError(msg, exception)
        throw exception
      }
      sender ! mapOutputStatuses  //将结果返回给发消息者


最后,回到原来的BlockStoreShuffleFetcher.fetch() 方法中,查看一下如下代码:

    val blockFetcherItr = new ShuffleBlockFetcherIterator(  //抓取块数据的迭代器
      context,
      SparkEnv.get.blockManager.shuffleClient,  //所使用的块传输服务器
      blockManager,
      blocksByAddress,
      serializer,
      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)  //远程块抓取的最大值,默认48M
    val itr = blockFetcherItr.flatMap(unpackBlock)

    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      context.taskMetrics.updateShuffleReadMetrics()
    })

    new InterruptibleIterator[T](context, completionIter)  //最后封装之后返回

查看一下块传输服务器:

  private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
    val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
    new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())  //从外部的服务器读取块(Executor外部)
  } else {
    blockTransferService   // 直接从其他Executor中读取块(如果Executor失败了,我们将不能读取在其内部shuffle数据了)
  }
blockTransferService 是在SparkEnv创建时指定的,默认使用netty服务器:

    val blockTransferService =
      conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
        case "netty" =>
          new NettyBlockTransferService(conf, securityManager, numUsableCores)
        case "nio" =>
          new NioBlockTransferService(conf, securityManager)
      }

************  The End  ************


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值