Spark Source Code Analysis (7): Task Analysis

This post walks through how a Task is executed in Spark: the creation of the TaskRunner, the deserialization and execution of the Task, and, in particular, the runTask method of ShuffleMapTask.

Task

This post mainly looks at what happens on the executor after it receives a LaunchTask message.
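For context, the LaunchTask message is received by the executor's backend endpoint (CoarseGrainedExecutorBackend in standalone/YARN mode). A simplified sketch of that handling, based on the Spark 2.x source with error handling trimmed, looks roughly like this:

// CoarseGrainedExecutorBackend.receive (simplified sketch)
override def receive: PartialFunction[Any, Unit] = {
  case LaunchTask(data) =>
    if (executor == null) {
      // The Executor has not been created yet (registration with the driver failed)
      exitExecutor(1, "Received LaunchTask command but executor was null")
    } else {
      // Decode the serialized TaskDescription and hand it to the Executor
      val taskDesc = TaskDescription.decode(data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      executor.launchTask(this, taskDesc)
    }
  // Other messages (KillTask, StopExecutor, ...) omitted
}

The executor.launchTask call is where the Executor class below takes over.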
Executor

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
 // Wrap each task in a TaskRunner (a Runnable executed on the thread pool below)
 val tr = new TaskRunner(context, taskDescription)

 runningTasks.put(taskDescription.taskId, tr)
 threadPool.execute(tr)
}

As shown above, a TaskRunner is created for each task,
and the TaskRunner is then submitted to a thread pool for execution.
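The threadPool here is a cached pool of daemon threads, so each running task gets its own worker thread and idle threads are reclaimed after a timeout. A minimal sketch of how such a pool can be built (the actual Executor field is created through a named daemon ThreadFactory; the object name below is illustrative):

import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.concurrent.atomic.AtomicInteger

// Sketch of a cached daemon thread pool, similar in spirit to Executor.threadPool
object TaskLaunchPool {
  private val counter = new AtomicInteger(0)

  private val threadFactory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"Executor task launch worker-${counter.getAndIncrement()}")
      t.setDaemon(true) // daemon threads do not keep the JVM alive at shutdown
      t
    }
  }

  // A cached pool grows on demand and reclaims threads that stay idle
  val pool: ThreadPoolExecutor =
    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
}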
TaskRunner

override def run(): Unit = {
   // Get the id of the current thread
   threadId = Thread.currentThread.getId
   // Set the name of the current thread
   Thread.currentThread.setName(threadName)
   val threadMXBean = ManagementFactory.getThreadMXBean
   val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
   // Wall-clock time and CPU time at which deserialization starts
   val deserializeStartTime = System.currentTimeMillis()
   val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
     threadMXBean.getCurrentThreadCpuTime
   } else 0L
   Thread.currentThread.setContextClassLoader(replClassLoader)
   val ser = env.closureSerializer.newInstance()
   logInfo(s"Running $taskName (TID $taskId)")
   execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
   var taskStartTime: Long = 0
   var taskStartCpu: Long = 0
   startGCTime = computeTotalGcTime()

   try {
     // Must be set before updateDependencies() is called, in case fetching dependencies
     // requires access to properties contained within (e.g. for access control).
     Executor.taskDeserializationProps.set(taskDescription.properties)

     // Download the files and jars the task depends on to the local machine
     updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)

     // Deserialize the task itself.
     // A Java ClassLoader is used here because it can look up class information
     // reflectively and instantiate the corresponding objects.
     task = ser.deserialize[Task[Any]](
       taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)

     task.localProperties = taskDescription.properties
     task.setTaskMemoryManager(taskMemoryManager)

     // If this task has been killed before we deserialized it, let's quit now. Otherwise,
     // continue executing the task.
     val killReason = reasonIfKilled
     if (killReason.isDefined) {
       // Throw an exception rather than returning, because returning within a try{} block
       // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
       // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
       // for the task.
       throw new TaskKilledException(killReason.get)
     }

     // The purpose of updating the epoch here is to invalidate executor map output status cache
     // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
     // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
     // we don't need to make any special calls here.
     if (!isLocal) {
       logDebug("Task " + taskId + "'s epoch is " + task.epoch)
       env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
     }

     // The most critical part
     // Run the actual task and measure its runtime.

     taskStartTime = System.currentTimeMillis()
     taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
       threadMXBean.getCurrentThreadCpuTime
     } else 0L
     var threwException = true
     // For a ShuffleMapTask, this `value` is a MapStatus, which wraps where the
     // task's computed shuffle output was written
     val value = Utils.tryWithSafeFinally {
       // Invoke the task's run method
       val res = task.run(
         taskAttemptId = taskId,
         attemptNumber = taskDescription.attemptNumber,
         metricsSystem = env.metricsSystem)
       threwException = false
       res
     } {
       val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
       val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

       if (freedMemory > 0 && !threwException) {
         val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
         if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
           throw new SparkException(errMsg)
         } else {
           logWarning(errMsg)
         }
       }

       if (releasedLocks.nonEmpty && !threwException) {
         val errMsg =
           s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
             releasedLocks.mkString("[", ", ", "]")
         if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
           throw new SparkException(errMsg)
         } else {
           logInfo(errMsg)
         }
       }
     }
     task.context.fetchFailed.foreach { fetchFailure =>
       // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
       // other exceptions.  It's *possible* this is what the user meant to do (though highly
       // unlikely).  So we will log an error and keep going.
       logError(s"TID ${taskId} completed successfully though internally it encountered " +
         s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
         s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
     }
     val taskFinish = System.currentTimeMillis()
     val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
       threadMXBean.getCurrentThreadCpuTime
     } else 0L

     // If the task has been killed, let's fail it.
     task.context.killTaskIfInterrupted()

     /**
       * Serialize and wrap the result (for a ShuffleMapTask, the MapStatus),
       * because it will later be sent back to the driver over the network.
       */
     val resultSer = env.serializer.newInstance()
     val beforeSerialization = System.currentTimeMillis()
     val valueBytes = resultSer.serialize(value)
     val afterSerialization = System.currentTimeMillis()

     // Deserialization happens in two parts: first, we deserialize a Task object, which
     // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
     // The following records various metrics
     task.metrics.setExecutorDeserializeTime(
       (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
     task.metrics.setExecutorDeserializeCpuTime(
       (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
     // We need to subtract Task.run()'s deserialization time to avoid double-counting
     task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
     task.metrics.setExecutorCpuTime(
       (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
     task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
     task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)

     // Expose task metrics using the Dropwizard metrics system.
     // Update task metrics counters
     executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
     executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
     executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
     executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
     executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
     executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
     executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
       .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
     executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
     executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
       .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
     executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
       .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
     executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
       .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
     executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
       .inc(task.metrics.shuffleReadMetrics.localBytesRead)
     executorSource.METRIC_SHUFFLE_RECORDS_READ
       .inc(task.metrics.shuffleReadMetrics.recordsRead)
     executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
       .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
     executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
       .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
     executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
       .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
     executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
       .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
     executorSource.METRIC_INPUT_BYTES_READ
       .inc(task.metrics.inputMetrics.bytesRead)
     executorSource.METRIC_INPUT_RECORDS_READ
       .inc(task.metrics.inputMetrics.recordsRead)
     executorSource.METRIC_OUTPUT_BYTES_WRITTEN
       .inc(task.metrics.outputMetrics.bytesWritten)
     executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
       .inc(task.metrics.outputMetrics.recordsWritten)
     executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
     executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
     executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)

     // Note: accumulator updates must be collected after TaskMetrics is updated
     val accumUpdates = task.collectAccumulatorUpdates()
     // TODO: do not serialize value twice
     val directResult = new DirectTaskResult(valueBytes, accumUpdates)
     val serializedDirectResult = ser.serialize(directResult)
     val resultSize = serializedDirectResult.limit()

     // directSend = sending directly back to the driver
     val serializedResult: ByteBuffer = {
       if (maxResultSize > 0 && resultSize > maxResultSize) {
         logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
           s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
           s"dropping it.")
         ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
       } else if (resultSize > maxDirectResultSize) {
         val blockId = TaskResultBlockId(taskId)
         env.blockManager.putBytes(
           blockId,
           new ChunkedByteBuffer(serializedDirectResult.duplicate()),
           StorageLevel.MEMORY_AND_DISK_SER)
         logInfo(
           s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
         ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
       } else {
         logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
         serializedDirectResult
       }
     }

     setTaskFinishedAndClearInterruptStatus()
     // Call statusUpdate on the executor's backend (CoarseGrainedExecutorBackend in cluster mode).
     // serializedResult contains either the result itself or, for large results, the location where it is stored.
     // The call below ultimately sends the task's latest state, together with the result (or its location), to the driver.
     execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

   } catch {
     case t: TaskKilledException =>
       logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")

       val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
       val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
       execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

     case _: InterruptedException | NonFatal(_) if
         task != null && task.reasonIfKilled.isDefined =>
       val killReason = task.reasonIfKilled.getOrElse("unknown reason")
       logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")

       val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
       val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums))
       execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

     case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
       val reason = task.context.fetchFailed.get.toTaskFailedReason
       if (!t.isInstanceOf[FetchFailedException]) {
         // there was a fetch failure in the task, but some user code wrapped that exception
         // and threw something else.  Regardless, we treat it as a fetch failure.
         val fetchFailedCls = classOf[FetchFailedException].getName
         logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
           s"failed, but the ${fetchFailedCls} was hidden by another " +
           s"exception.  Spark is handling this like a fetch failure and ignoring the " +
           s"other exception: $t")
       }
       setTaskFinishedAndClearInterruptStatus()
       execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

     case CausedBy(cDE: CommitDeniedException) =>
       val reason = cDE.toTaskCommitDeniedReason
       setTaskFinishedAndClearInterruptStatus()
       execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

     case t: Throwable =>
       // Attempt to exit cleanly by informing the driver of our failure.
       // If anything goes wrong (or this was a fatal exception), we will delegate to
       // the default uncaught exception handler, which will terminate the Executor.
       logError(s"Exception in $taskName (TID $taskId)", t)

       // SPARK-20904: Do not report failure to driver if it happened during shut down. Because
       // libraries may set up shutdown hooks that race with running tasks during shutdown,
       // spurious failures may occur and can result in improper accounting in the driver (e.g.
       // the task failure would not be ignored if the shutdown happened because of preemption,
       // instead of an app issue).
       if (!ShutdownHookManager.inShutdown()) {
         val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)

         val serializedTaskEndReason = {
           try {
             ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
           } catch {
             case _: NotSerializableException =>
               // t is not serializable so just send the stacktrace
               ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
           }
         }
         setTaskFinishedAndClearInterruptStatus()
         execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
       } else {
         logInfo("Not reporting error to driver during JVM shutdown.")
       }

       // Don't forcibly exit unless the exception was inherently fatal, to avoid
       // stopping other tasks unnecessarily.
       if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
         uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
       }
   } finally {
     runningTasks.remove(taskId)
   }
 }
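The three-way branch near the end of run() decides how the result gets back to the driver: drop it, ship it via the BlockManager, or send it inline with the status update. The two thresholds it compares against come from configuration; roughly (names and defaults as in Spark 2.x, so treat the exact values as version dependent):

// Fields of Executor that back the serializedResult decision above (simplified sketch)
// spark.driver.maxResultSize (default 1g): results larger than this are dropped and
// only an IndirectTaskResult carrying the size is reported back
val maxResultSize: Long = conf.getSizeAsBytes("spark.driver.maxResultSize", "1g")
// spark.task.maxDirectResultSize (default 1 MiB), further capped by the RPC max message
// size: results above this are written to the BlockManager and fetched by the driver later
val maxDirectResultSize: Long = math.min(
  conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20),
  RpcUtils.maxMessageSizeBytes(conf))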

When run() is called on a task, it ultimately executes that task's runTask() method.
The following uses ShuffleMapTask as the example; a simplified sketch of run() itself comes first.
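Task.run (heavily simplified here; the real method also handles kill reasons, error cases, and memory cleanup, and the TaskContextImpl constructor arguments vary across Spark versions) registers the task with the BlockManager, builds a TaskContext, installs it as a thread-local, and then delegates the actual work to the subclass's runTask():

// Task.run, simplified sketch
final def run(taskAttemptId: Long, attemptNumber: Int, metricsSystem: MetricsSystem): T = {
  SparkEnv.get.blockManager.registerTask(taskAttemptId)
  // Build the TaskContext that user code can reach via TaskContext.get()
  context = new TaskContextImpl(stageId, stageAttemptId, partitionId, taskAttemptId,
    attemptNumber, taskMemoryManager, localProperties, metricsSystem, metrics)
  TaskContext.setTaskContext(context)
  taskThread = Thread.currentThread()
  try {
    // ShuffleMapTask returns a MapStatus; ResultTask returns the result of the user function
    runTask(context)
  } finally {
    context.markTaskCompleted(None)
    // Clear the thread-local so the worker thread can safely be reused
    TaskContext.unset()
  }
}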

/*
 * An important point: this method returns a MapStatus.
 */
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  // Deserialize the RDD (and shuffle dependency) this task needs to process.
  // The key question is how each task gets hold of this RDD: the tasks of a stage run
  // in parallel (or concurrently) on multiple executors, but they all process the same RDD.
  // Each task obtains the RDD definition through a broadcast variable (taskBinary).
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  var writer: ShuffleWriter[Any, Any] = null
  try {
    // Get the ShuffleManager
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    // First, rdd.iterator is called with the partition this task is responsible for;
    // the core compute logic lives inside rdd.iterator (see the sketch after this block).
    // The records it returns are partitioned by the dependency's partitioner
    // (HashPartitioner by default) and written by the ShuffleWriter into the
    // bucket for their target reduce partition.
    // The hash-based shuffle writer was removed in Spark 2.0; the sort-based
    // writer (SortShuffleWriter and its variants) is used instead.
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    // Finally, return a MapStatus, which describes where the ShuffleMapTask's output
    // is stored, essentially information about the local BlockManager.
    // The BlockManager is Spark's component for managing memory, data, and disk.
    writer.stop(success = true).get
  } catch {
    case e: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case e: Exception =>
          log.debug("Could not stop writer", e)
      }
      throw e
  }
}
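The comment in runTask points out that the core compute logic lives in rdd.iterator. For reference, RDD.iterator is a thin dispatcher (simplified sketch of the Spark 2.x source): it either serves the partition from cache/checkpoint or computes it, and compute() in turn pulls from the parent RDDs, so the whole pipeline of a stage is evaluated lazily through this single call.

// RDD.iterator, simplified sketch
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // The RDD is marked for caching: read the partition from the block store,
    // or compute it and cache it
    getOrCompute(split, context)
  } else {
    // Not cached: read from a checkpoint if one exists, otherwise call compute()
    computeOrReadCheckpoint(split, context)
  }
}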

Flow summary:

  1. After the Executor receives a LaunchTask message, it creates a TaskRunner and submits it to a thread pool.
  2. While the TaskRunner runs, it deserializes the task and downloads the required files and jars, then calls the task's run() method, which performs the actual computation (ultimately by calling the RDD's iterator()). If the task is a ShuffleMapTask, it returns a MapStatus when it finishes; this object records where the task's output is stored and the size of the output for each reducer. When the task's status update is sent to the driver, the MapStatus is sent along with it (see the sketch after this list).
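For reference, the MapStatus mentioned in step 2 is essentially "where the shuffle output lives, plus an estimated size per reduce partition". A simplified sketch of its shape (the concrete implementations in Spark compress the size information):

// MapStatus, simplified sketch: the value a ShuffleMapTask hands back to the scheduler
private[spark] sealed trait MapStatus {
  // The BlockManager (i.e. the executor) that holds this task's shuffle output
  def location: BlockManagerId
  // Estimated output size for a given reduce partition, used when planning fetches
  def getSizeForBlock(reduceId: Int): Long
}
// MapStatus.apply picks CompressedMapStatus for small partition counts and
// HighlyCompressedMapStatus (which mostly tracks an average size) when there are many partitions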