Spark Source Code, Part 4: From Task Execution on the Executor to Result Output

Spark Shuffle in Depth
This post digs into how shuffle works in Spark: how a Task is serialized and deserialized, how the shuffle computation is carried out, and what roles key components such as the Executor and the ShuffleManager play in the process.

Previous post in this series: https://blog.youkuaiyun.com/cw1254332663/article/details/95327497

These are notes from my own study; if anything is wrong, please point it out. Thanks!

Last time we saw that DriverEndpoint, an inner class of CoarseGrainedSchedulerBackend, serializes the TaskDescription and wraps it in a LaunchTask case class.

CoarseGrainedExecutorBackend has a receive method for handling incoming messages:

override def receive: PartialFunction[Any, Unit] = {
    // An executor is only created once this message is received
    case RegisteredExecutor =>
      logInfo("Successfully registered with driver")
      try {
        executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
      } catch {
        case NonFatal(e) =>
          exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
      }

    case RegisterExecutorFailed(message) =>
      exitExecutor(1, "Slave registration failed: " + message)

    // This is the message we pattern-match on here
    case LaunchTask(data) =>
      // Check whether the executor exists
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        // Deserialize the TaskDescription
        val taskDesc = TaskDescription.decode(data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        // Call the executor's launchTask method to run the task
        executor.launchTask(this, taskDesc)
      }

    // Kill a task
    case KillTask(taskId, _, interruptThread, reason) =>
      if (executor == null) {
        exitExecutor(1, "Received KillTask command but executor was null")
      } else {
        executor.killTask(taskId, interruptThread, reason)
      }

    case StopExecutor =>
      stopping.set(true)
      logInfo("Driver commanded a shutdown")
      self.send(Shutdown)

    case Shutdown =>
      stopping.set(true)
      new Thread("CoarseGrainedExecutorBackend-stop-executor") {
        override def run(): Unit = {
          executor.stop()
        }
      }.start()
  }

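The dispatch mechanism above is just a PartialFunction[Any, Unit] matching case-class messages. Below is a minimal, self-contained sketch of that pattern; the names (ReceiveSketch, MiniLaunchTask, MiniKillTask, handle) are made up for illustration and are not Spark classes.

object ReceiveSketch {
  sealed trait Message
  case class MiniLaunchTask(taskId: Long, payload: Array[Byte]) extends Message
  case class MiniKillTask(taskId: Long, reason: String) extends Message

  // The "receive" function: a partial function that matches on the message type
  val handle: PartialFunction[Any, Unit] = {
    case MiniLaunchTask(id, bytes) =>
      println(s"launching task $id with ${bytes.length} bytes")
    case MiniKillTask(id, reason) =>
      println(s"killing task $id: $reason")
  }

  def main(args: Array[String]): Unit = {
    handle(MiniLaunchTask(1L, Array[Byte](1, 2, 3)))
    handle(MiniKillTask(1L, "speculative duplicate"))
  }
}
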
Let's look at the executor.launchTask method:

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
    // Wrap the ExecutorBackend and the TaskDescription in a Runnable (TaskRunner)
    val tr = new TaskRunner(context, taskDescription)
    // runningTasks keeps track of the tasks that are currently running
    runningTasks.put(taskDescription.taskId, tr)
    // Hand the Runnable to the thread pool for execution
    threadPool.execute(tr)
  }

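In other words, the executor keeps a map of the tasks in flight and defers the actual work to a thread pool. Here is a simplified sketch of that pattern using plain JDK concurrency classes; TaskStub and LaunchTaskSketch are hypothetical names, not Spark's TaskDescription/TaskRunner.

import java.util.concurrent.{ConcurrentHashMap, Executors}

object LaunchTaskSketch {
  final case class TaskStub(taskId: Long, name: String)

  // Mirrors the idea of Executor.runningTasks and Executor.threadPool
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()
  private val threadPool = Executors.newCachedThreadPool()

  def launchTask(task: TaskStub): Unit = {
    val runner = new Runnable {
      override def run(): Unit = {
        try println(s"running ${task.name} (TID ${task.taskId})")
        finally runningTasks.remove(task.taskId) // Spark likewise removes the entry when the task finishes
      }
    }
    runningTasks.put(task.taskId, runner)
    threadPool.execute(runner)
  }

  def main(args: Array[String]): Unit = {
    launchTask(TaskStub(0L, "task 0.0 in stage 0.0"))
    threadPool.shutdown()
  }
}
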
Now let's look at the logic of TaskRunner's run method:

override def run(): Unit = {
      /*
       A block of setup and bookkeeping code is omitted here
           ......
       */
      // Get the closure serializer
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      // Report the RUNNING state back to the cluster scheduler
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      var taskStartCpu: Long = 0
      startGCTime = computeTotalGcTime()
      try {
        Executor.taskDeserializationProps.set(taskDescription.properties)
        // If a new set of files and JARs was received from the SparkContext, download any
        // missing dependencies and load the new JARs with the class loader
        updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
        // Deserialize the task
        task = ser.deserialize[Task[Any]](
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskDescription.properties
        task.setTaskMemoryManager(taskMemoryManager)
        // If the task was killed before it could be deserialized, quit now; otherwise keep going
        val killReason = reasonIfKilled
        if (killReason.isDefined) {
          throw new TaskKilledException(killReason.get)
        }
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)
        // Run the actual task and measure how long it takes
        taskStart = System.currentTimeMillis()
        taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        var threwException = true
        val value = try {
          // This is the key call: run the task
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = taskDescription.attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } 
      }
        ......
    }

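With the default closure serializer (JavaSerializer), the ser.deserialize call above is essentially a Java-serialization round trip hidden behind Spark's SerializerInstance API. The sketch below illustrates that round trip with plain java.io serialization; FakeTask and SerializationSketch are made-up stand-ins, not Spark's Task or serializer classes.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object SerializationSketch {
  // Case classes are Serializable by default
  case class FakeTask(stageId: Int, partitionId: Int)

  def serialize(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj); oos.close()
    bos.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    ois.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    // Round-trip an object through bytes, as the executor does with the serialized task
    val task = deserialize[FakeTask](serialize(FakeTask(stageId = 1, partitionId = 0)))
    println(task) // FakeTask(1,0)
  }
}
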
Stepping into task.run():

final def run(
      taskAttemptId: Long,
      attemptNumber: Int,
      metricsSystem: MetricsSystem): T = {
    SparkEnv.get.blockManager.registerTask(taskAttemptId)
   
    // A few objects (such as the TaskContext) are initialized here
        ......

    try {
      // runTask is invoked here
      runTask(context)
    } 

    ......
  }

Following into ShuffleMapTask.runTask:

override def runTask(context: TaskContext): MapStatus = {
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    
    // Deserialize the RDD using a broadcast variable; taskBinary is a Broadcast[Array[Byte]]
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the shuffleManager.
      // There are three shuffle managers: Hash, Sort (the default), and Tungsten-sort; a later post will cover them in detail.
      val manager = SparkEnv.get.shuffleManager
      // Get the writer
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // The next line breaks down into two steps:
      // 1. computing the RDD
      // 2. writing out the computed results
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

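For context, a simple two-stage job shows where ShuffleMapTask comes into play. Assuming spark-shell (where sc is predefined), the upstream, map-side stage of the job below runs ShuffleMapTasks that execute the runTask above, while the final stage runs ResultTasks.

val counts = sc.parallelize(Seq("a", "b", "a"))
  .map(word => (word, 1))
  .reduceByKey(_ + _)    // shuffle boundary: the upstream stage writes shuffle output here
counts.collect().foreach(println)
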
Let's start with step 1, computing the RDD (RDD.iterator):

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    // Check whether this RDD has been persisted
    if (storageLevel != StorageLevel.NONE) {
      // Get or compute this partition of the RDD
      getOrCompute(split, context)
    } else {
      // Compute the partition, or read it from a checkpoint
      computeOrReadCheckpoint(split, context)
    }
  }

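At the API level, the first branch is what caching enables. A small usage illustration, assuming spark-shell: once the RDD is persisted with a storage level other than NONE, later actions go through getOrCompute and can read the cached blocks instead of recomputing the partition.

import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(1 to 1000000).map(_ * 2)
data.persist(StorageLevel.MEMORY_ONLY)
data.count()   // the first action computes the partitions and caches the blocks
data.count()   // the second action finds the blocks via getOrCompute and skips recomputation
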
The getOrCompute method:

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    // Get the block id of this RDD partition's cached block
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // If the block already exists, retrieve it.
    // Otherwise, call the provided `makeIterator` function to compute the block, persist it, and return its values.
    // A BlockResult (Left) is returned when the block exists; an Iterator (Right) is returned when it does not.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }

The computeOrReadCheckpoint method:

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      // Call the parent RDD's iterator method (reads the checkpointed data)
      firstParent[T].iterator(split, context)
    } else {
      // Run the actual compute logic
      compute(split, context)
    }
  }

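A small usage illustration of the first branch, assuming spark-shell (the checkpoint directory is only an example): once the checkpoint has been materialized, iterator calls read the data back through the parent CheckpointRDD instead of recomputing the lineage.

sc.setCheckpointDir("/tmp/spark-checkpoints")   // example path
val rdd = sc.parallelize(1 to 100).map(_ + 1)
rdd.checkpoint()   // mark the RDD for checkpointing
rdd.count()        // an action materializes the RDD and writes the checkpoint files
rdd.count()        // later computations read from the checkpoint via firstParent
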
Next, the compute method of a MapPartitionsRDD (produced by a non-shuffle operator):

override def compute(split: Partition, context: TaskContext): Iterator[U] =
    // Call the parent RDD's iterator method here
    f(context, split.index, firstParent[T].iterator(split, context))

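Here f has the shape (TaskContext, partitionIndex, parentIterator) => Iterator[U]. At the API level this is easiest to see with mapPartitionsWithIndex, which builds a MapPartitionsRDD whose f forwards the partition index and the parent's iterator (assuming spark-shell):

val rdd = sc.parallelize(1 to 10, numSlices = 2)
val tagged = rdd.mapPartitionsWithIndex { (index, iter) =>
  iter.map(x => s"partition $index -> value $x")
}
tagged.collect().foreach(println)
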
And the compute method of an RDD produced by a shuffle operator (here, ShuffledRowRDD):

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
    // Get the reader
    val reader =
      SparkEnv.get.shuffleManager.getReader(
        dependency.shuffleHandle,
        shuffledRowPartition.startPreShufflePartitionIndex,
        shuffledRowPartition.endPreShufflePartitionIndex,
        context)
    // Call read(), which looks up (via the MapOutputTracker) where the upstream stage's intermediate shuffle files live and fetches those blocks
    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
  }

Now step 2, writer.write() (here, SortShuffleWriter.write):

override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      // ExternalSorter maintains two in-memory structures. For aggregating operators such as reduceByKey,
      // records first go into a map structure; each insert checks whether the estimated size has reached a threshold,
      // and when it has, the contents are spilled and the structure is cleared.
      // Spills go through a buffer (32K by default) that is written to a file once it fills up.
      // Ordinary shuffle operators such as join use an array structure instead.
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the ExternalSorter,
      // because we don't care whether the keys get sorted within each partition;
      // if the operation being run is sortByKey, the sorting happens on the reduce side.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    // Insert the records into the in-memory structure (map or array); every insert checks whether a spill is needed
    sorter.insertAll(records)

    // With SortShuffleManager, each map task outputs one data file plus one index file
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    // Create a temporary file in the same directory
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // Write all the data added to the ExternalSorter out to disk
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // Write an index file containing the offset of each block, plus a final offset marking the end of the output file.
      // getBlockData uses it to locate each block when the data is fetched.
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }
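
Which branch the sorter takes depends on the operator. A quick contrast, assuming spark-shell: reduceByKey sets mapSideCombine = true, so the aggregating map structure is used, while groupByKey sets mapSideCombine = false and records go straight into the buffer/array structure.

val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))
pairs.reduceByKey(_ + _).collect()   // map-side combine: partial sums are computed before the shuffle
pairs.groupByKey().collect()         // no map-side combine: every record is shuffled as-is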

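The index file itself boils down to cumulative offsets: for n reduce partitions it stores n + 1 longs, starting at 0, so that a reader can slice the single data file. Below is a small sketch of that bookkeeping in plain Scala; it is not Spark's IndexShuffleBlockResolver code, just the arithmetic it performs.

object IndexOffsetsSketch {
  // n partition lengths -> n + 1 offsets: 0 followed by the running totals
  def offsets(partitionLengths: Array[Long]): Array[Long] =
    partitionLengths.scanLeft(0L)(_ + _)

  def main(args: Array[String]): Unit = {
    val lengths = Array(100L, 0L, 250L)        // bytes written for each reduce partition
    println(offsets(lengths).mkString(", "))   // prints: 0, 100, 100, 350
    // A reducer fetching partition i reads bytes [offsets(i), offsets(i + 1)) of the data file.
  }
}
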
There is a lot going on here; I'll keep filling in the details in later posts......
