Spark executor process

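This article walks through how Spark executes tasks. CoarseGrainedExecutorBackend is the entry point of the executor process: it starts a WorkerWatcher endpoint, connects to the driver, and registers itself. Communication goes through NettyRpcEnv or AkkaRpcEnv. The Executor class runs the tasks: it deserializes each task and its dependencies, then runs the Task (a ShuffleMapTask or a ResultTask). When a task finishes, its result is sent back to the driver. The whole process covers receiving tasks, managing dependencies, and reporting results.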


The entry class of the process that executes Spark tasks is CoarseGrainedExecutorBackend.

Inside the CoarseGrainedExecutorBackend process, a WorkerWatcher endpoint is also started (when a worker URL is supplied) to watch the Worker process on the same machine:

  // Set up the CoarseGrainedExecutorBackend endpoint; it connects to the Spark driver and receives tasks to run
  env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
    env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
  // Watch the Spark Worker process: if the Worker goes down, the watcher notices and this process exits too
  workerUrl.foreach { url =>
    env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
  }
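
The watcher relies on the RPC layer telling every endpoint when a remote peer disconnects; the WorkerWatcher only reacts when the lost address is its own worker. A minimal sketch of that pattern, using hypothetical simplified classes (MiniRpcEnv, WorkerWatcherSketch) rather than Spark's real RpcEnv API. The exit action is injected so the sketch does not kill the JVM, and the addresses are made up:

```scala
import scala.collection.mutable

// Hypothetical stand-in for an RPC endpoint: only the disconnect callback is modeled.
trait Endpoint {
  def onDisconnected(remoteAddress: String): Unit = ()
}

// Hypothetical, in-process stand-in for RpcEnv (not Spark's API).
final class MiniRpcEnv {
  private val endpoints = mutable.LinkedHashMap.empty[String, Endpoint]

  def setupEndpoint(name: String, endpoint: Endpoint): Unit = endpoints(name) = endpoint

  // Simulates the transport noticing that a remote process went away:
  // every registered endpoint is told which address was lost.
  def simulateDisconnect(remoteAddress: String): Unit =
    endpoints.values.foreach(_.onDisconnected(remoteAddress))
}

// Watches one worker address and requests a non-zero exit when that worker disconnects.
final class WorkerWatcherSketch(workerAddress: String, exit: Int => Unit) extends Endpoint {
  override def onDisconnected(remoteAddress: String): Unit =
    if (remoteAddress == workerAddress) exit(1)
}

val exits = mutable.ArrayBuffer.empty[Int]
val rpcEnv = new MiniRpcEnv
rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcherSketch("worker-host:7078", code => exits += code))
rpcEnv.simulateDisconnect("other-host:1234")   // unrelated peer: ignored
val exitsAfterUnrelatedPeer = exits.size
rpcEnv.simulateDisconnect("worker-host:7078")  // watched worker lost: exit requested
```
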

The backend is constructed from the parameters the process receives at startup:

private[spark] class CoarseGrainedExecutorBackend(
    override val rpcEnv: RpcEnv,
    driverUrl: String,
    executorId: String,
    hostPort: String,
    cores: Int,
    userClassPath: Seq[URL],
    env: SparkEnv)
  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {

  var executor: Executor = null
  @volatile var driver: Option[RpcEndpointRef] = None

  // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need
  // to be changed so that we don't share the serializer instance across threads
  private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()

  override def onStart() {
    logInfo("Connecting to driver: " + driverUrl)
    // Connect to the driver
    rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
      // This is a very fast action so we can use "ThreadUtils.sameThread"
      driver = Some(ref)
      // Register this executor with the driver
      ref.ask[RegisterExecutorResponse](
        RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
    }(ThreadUtils.sameThread).onComplete {
      // This is a very fast action so we can use "ThreadUtils.sameThread"
      case Success(msg) => Utils.tryLogNonFatalError {
        // Forward the driver's reply to ourselves so that receive() handles it
        Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse
      }
      case Failure(e) => {
        logError(s"Cannot register with driver: $driverUrl", e)
        System.exit(1)
      }
    }(ThreadUtils.sameThread)
  }

  def extractLogUrls: Map[String, String] = {
    val prefix = "SPARK_LOG_URL_"
    sys.env.filterKeys(_.startsWith(prefix))
      .map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
  }
  // The message handler: every message sent to this endpoint arrives here
  override def receive: PartialFunction[Any, Unit] = {
    case RegisteredExecutor(hostname) =>
      logInfo("Successfully registered with driver")
      // Registration succeeded, so create the Executor that will run tasks
      executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)

    case RegisterExecutorFailed(message) =>
      logError("Slave registration failed: " + message)
      System.exit(1)

    case LaunchTask(data) =>
      if (executor == null) {
        logError("Received LaunchTask command but executor was null")
        System.exit(1)
      } else {
        // Deserialize the task description sent by the driver
        val taskDesc = ser.deserialize[TaskDescription](data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        // Hand the task to the Executor to run
        executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
          taskDesc.name, taskDesc.serializedTask)
      }

    case KillTask(taskId, _, interruptThread) =>
      if (executor == null) {
        logError("Received KillTask command but executor was null")
        System.exit(1)
      } else {
        executor.killTask(taskId, interruptThread)
      }

    case StopExecutor =>
      logInfo("Driver commanded a shutdown")
      // Cannot shutdown here because an ack may need to be sent back to the caller. So send
      // a message to self to actually do the shutdown.
      self.send(Shutdown)

    case Shutdown =>
      executor.stop()
      stop()
      rpcEnv.shutdown()
  }

  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    if (driver.exists(_.address == remoteAddress)) {
      logError(s"Driver $remoteAddress disassociated! Shutting down.")
      // The driver has disconnected, so this executor process exits as well
      System.exit(1)
    } else {
      logWarning(s"An unknown ($remoteAddress) driver disconnected.")
    }
  }

  // Report a task's state change (RUNNING, FINISHED, FAILED, ...) back to the driver
  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.send(msg)
      case None => logWarning(s"Drop $msg because has not yet connected to driver")
    }
  }
}

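The class above connects to the driver and registers this executor process with it. It extends ThreadSafeRpcEndpoint, a trait derived from RpcEndpoint.
The connection is made with rpcEnv.asyncSetupEndpointRefByURI(driverUrl).
In this Spark version RpcEnv has two implementations, NettyRpcEnv and AkkaRpcEnv, which shows that Spark's underlying transport is either Netty or Akka.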
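
The connect-then-register chain in onStart() is ordinary Future composition: flatMap sequences the connection and the ask, and onComplete handles the outcome. A minimal sketch with scala.concurrent, where DriverRef, connect and register are hypothetical stand-ins for RpcEndpointRef, asyncSetupEndpointRefByURI and onStart(), and the driver URL is made up:

```scala
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.util.{Failure, Success}

// Hypothetical, simplified replies (the real RegisteredExecutor also carries a hostname).
sealed trait RegisterResponse
case object RegisteredExecutor extends RegisterResponse
final case class RegisterExecutorFailed(message: String) extends RegisterResponse

// Hypothetical driver reference that answers registration requests asynchronously.
final class DriverRef(acceptedIds: Set[String])(implicit ec: ExecutionContext) {
  def askRegister(executorId: String): Future[RegisterResponse] = Future {
    if (acceptedIds.contains(executorId)) RegisteredExecutor
    else RegisterExecutorFailed(s"executor $executorId rejected")
  }
}

// Hypothetical async connect, standing in for rpcEnv.asyncSetupEndpointRefByURI.
def connect(driverUrl: String, driver: DriverRef): Future[DriverRef] =
  if (driverUrl.startsWith("spark://")) Future.successful(driver)
  else Future.failed(new IllegalArgumentException(s"cannot connect to $driverUrl"))

// Same shape as onStart(): connect, flatMap into the registration request, and handle
// the outcome in onComplete. The promise completes only after that handler has run.
def register(driverUrl: String, executorId: String, driver: DriverRef)
            (implicit ec: ExecutionContext): Future[RegisterResponse] = {
  val handled = Promise[RegisterResponse]()
  connect(driverUrl, driver).flatMap(ref => ref.askRegister(executorId)).onComplete {
    case Success(msg) => handled.success(msg) // real code: self.send(msg) to reach receive()
    case Failure(e)   => handled.failure(e)   // real code: logError(...) then System.exit(1)
  }
  handled.future
}

implicit val ec: ExecutionContext = ExecutionContext.global
val driver = new DriverRef(Set("exec-1"))
val accepted = Await.result(register("spark://driver-host:7077", "exec-1", driver), 5.seconds)
val rejected = Await.result(register("spark://driver-host:7077", "exec-2", driver), 5.seconds)
val unreachable = Await.ready(register("http://bad-url", "exec-1", driver), 5.seconds).value
```

Forwarding the reply to the endpoint itself, as the real code does with self.send(msg), means the registration result is handled inside receive, on the same thread-safe path as every other message.
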
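Its receive method handles messages such as LaunchTask and KillTask; this is how the backend receives tasks to run (or kill) from the driver.
Once registration succeeds (the RegisteredExecutor message), it creates an Executor, which runs tasks on its own thread pool.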
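
Because receive is a PartialFunction[Any, Unit], message dispatch is plain pattern matching on case classes. A minimal sketch of that dispatch with hypothetical, simplified message types (BackendMessage, BackendSketch are illustrative names); it appends to a log instead of creating an Executor or calling System.exit:

```scala
import scala.collection.mutable

// Hypothetical, simplified versions of the messages handled by receive above.
sealed trait BackendMessage
final case class RegisteredExecutor(hostname: String) extends BackendMessage
final case class LaunchTask(taskId: Long, name: String) extends BackendMessage
final case class KillTask(taskId: Long) extends BackendMessage
case object StopExecutor extends BackendMessage

final class BackendSketch {
  val log = mutable.ArrayBuffer.empty[String]
  // Stands in for `executor: Executor = null`: None until registration succeeds.
  private var executorHost: Option[String] = None

  // Like RpcEndpoint.receive: a PartialFunction that pattern-matches on the message type.
  val receive: PartialFunction[Any, Unit] = {
    case RegisteredExecutor(host) =>
      executorHost = Some(host)                             // real code: new Executor(...)
      log += s"executor created on $host"
    case LaunchTask(id, _) if executorHost.isEmpty =>
      log += s"task $id rejected: executor not created yet" // real code: System.exit(1)
    case LaunchTask(id, name) =>
      log += s"launch $name (TID $id)"                      // real code: executor.launchTask(...)
    case KillTask(id) =>
      log += s"kill TID $id"                                // real code: executor.killTask(...)
    case StopExecutor =>
      log += "shutdown requested"                           // real code: self.send(Shutdown)
  }

  // Messages outside the PartialFunction's domain are logged instead of throwing MatchError.
  def deliver(msg: Any): Unit =
    if (receive.isDefinedAt(msg)) receive(msg) else log += s"unhandled: $msg"
}

val backend = new BackendSketch
backend.deliver(LaunchTask(0L, "task 0.0"))
backend.deliver(RegisteredExecutor("worker-1"))
backend.deliver(LaunchTask(1L, "task 1.0"))
backend.deliver(KillTask(1L))
backend.deliver("hello")
```
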

Below is the core code of the Executor class (launchTask and killTask):

  // Start running a task
  def launchTask(
      context: ExecutorBackend,
      taskId: Long,
      attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer): Unit = {
    // Wrap the task in a TaskRunner (a Runnable) and submit it to the thread pool
    val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
      serializedTask)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }

  // Kill a running task
  def killTask(taskId: Long, interruptThread: Boolean): Unit = {
    val tr = runningTasks.get(taskId)
    if (tr != null) {
      tr.kill(interruptThread)
    }
  }
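
launchTask registers the TaskRunner in runningTasks before handing it to the thread pool, so that killTask can find it; killing is cooperative (a flag, plus an optional thread interrupt). A minimal, self-contained sketch of that pattern, assuming a fixed-size java.util.concurrent pool. TaskRunnerSketch and ExecutorSketch are hypothetical names, and the real TaskRunner does far more (deserialization, metrics, result handling):

```scala
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, Executors, TimeUnit}

// Hypothetical TaskRunner: a Runnable with a volatile kill flag, checked cooperatively.
final class TaskRunnerSketch(val taskId: Long, work: () => Unit,
                             running: ConcurrentHashMap[Long, TaskRunnerSketch]) extends Runnable {
  @volatile var killed = false
  @volatile var finishedState: String = "PENDING"
  @volatile private var thread: Thread = null

  def kill(interruptThread: Boolean): Unit = {
    killed = true
    val t = thread
    if (interruptThread && t != null) t.interrupt()
  }

  override def run(): Unit = {
    thread = Thread.currentThread()
    try {
      if (killed) finishedState = "KILLED" // killed before it started running
      else { work(); finishedState = if (killed) "KILLED" else "FINISHED" }
    } catch {
      case _: InterruptedException => finishedState = "KILLED"
    } finally {
      running.remove(taskId) // mirrors runningTasks.remove(taskId)
    }
  }
}

final class ExecutorSketch(threads: Int) {
  val runningTasks = new ConcurrentHashMap[Long, TaskRunnerSketch]()
  private val threadPool = Executors.newFixedThreadPool(threads)

  def launchTask(taskId: Long, work: () => Unit): TaskRunnerSketch = {
    val tr = new TaskRunnerSketch(taskId, work, runningTasks)
    runningTasks.put(taskId, tr) // register before submitting, so killTask can find it
    threadPool.execute(tr)
    tr
  }

  def killTask(taskId: Long, interruptThread: Boolean): Unit = {
    val tr = runningTasks.get(taskId)
    if (tr != null) tr.kill(interruptThread)
  }

  def stop(): Unit = { threadPool.shutdown(); threadPool.awaitTermination(5, TimeUnit.SECONDS) }
}

val executor = new ExecutorSketch(threads = 2)
val started = new CountDownLatch(1)
val never = new CountDownLatch(1)
val quick = executor.launchTask(1L, () => ())
val slow = executor.launchTask(2L, () => { started.countDown(); never.await() })
started.await(5, TimeUnit.SECONDS)            // task 2 is now blocked inside its work
executor.killTask(2L, interruptThread = true) // the interrupt wakes it with InterruptedException
executor.stop()
```
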

The core of task execution is TaskRunner.run():

    override def run(): Unit = {
      // Per-task memory manager
      val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
      val deserializeStartTime = System.currentTimeMillis()
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      startGCTime = computeTotalGcTime()

      try {
        // Split the serialized task into the files and jars it depends on, plus the task bytes
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        // Fetch missing dependencies and add new jars to the class loader
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
        task.setTaskMemoryManager(taskMemoryManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        var threwException = true
        val (value, accumUpdates) = try {
          // Run the task itself (a ShuffleMapTask or a ResultTask)
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } finally {
          // Release any block locks the task still holds
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          // Free any managed memory the task did not release
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
              throw new SparkException(errMsg)
            } else {
              logError(errMsg)
            }
          }

          if (releasedLocks.nonEmpty) {
            val errMsg =
              s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]")
            if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) {
              throw new SparkException(errMsg)
            } else {
              logError(errMsg)
            }
          }
        }
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        // Serialize the task's return value
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()
        // Update the task metrics
        for (m <- task.metrics) {
          // Deserialization happens in two parts: first, we deserialize a Task object, which
          // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
          m.setExecutorDeserializeTime(
            (taskStart - deserializeStartTime) + task.executorDeserializeTime)
          // We need to subtract Task.run()'s deserialization time to avoid double-counting
          m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
          m.setJvmGCTime(computeTotalGcTime() - startGCTime)
          m.setResultSerializationTime(afterSerialization - beforeSerialization)
          m.updateAccumulators()
        }
        // Wrap the serialized value, accumulator updates and metrics in one result object
        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            // Result exceeds maxResultSize: drop the bytes and send only a reference with its size
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            // Result does not fit in one Akka frame: store it in the BlockManager instead
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }
        // Report FINISHED to the driver, together with the (direct or indirect) result
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch {
        case ffe: FetchFailedException =>
          val reason = ffe.toTaskEndReason
          // Report the fetch failure to the driver
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        case _: TaskKilledException | _: InterruptedException if task.killed =>
          logInfo(s"Executor killed $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))

        case CausedBy(cDE: CommitDeniedException) =>
          val reason = cDE.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

        case t: Throwable =>
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError(s"Exception in $taskName (TID $taskId)", t)

          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
            task.metrics.map { m =>
              m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              m.setJvmGCTime(computeTotalGcTime() - startGCTime)
              m.updateAccumulators()
              m
            }
          }
          val serializedTaskEndReason = {
            try {
              ser.serialize(new ExceptionFailure(t, metrics))
            } catch {
              case _: NotSerializableException =>
                // t is not serializable so just send the stacktrace
                ser.serialize(new ExceptionFailure(t, metrics, false))
            }
          }
          // Report the failure to the driver
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            SparkUncaughtExceptionHandler.uncaughtException(t)
          }

      } finally {
        runningTasks.remove(taskId)
      }
    }
  } // end of class TaskRunner
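
Stripped of metrics, serialization and the BlockManager, run() follows a fixed protocol: report RUNNING, run the task with cleanup in a finally block, then report exactly one terminal state (FINISHED, FAILED or KILLED). A minimal sketch of that control flow, with hypothetical names (runTask, report, cleanup) and string payloads instead of serialized ByteBuffers; unlike the real code, it does not escalate fatal errors to an uncaught-exception handler:

```scala
import scala.collection.mutable

// Hypothetical exception and state names, modeled on the ones used in run() above.
final class TaskKilledException extends RuntimeException
sealed trait TaskState
object TaskState {
  case object RUNNING extends TaskState
  case object FINISHED extends TaskState
  case object FAILED extends TaskState
  case object KILLED extends TaskState
}

// `report` stands in for execBackend.statusUpdate; `cleanup` for the lock/memory release in
// the inner finally block (it returns the number of leaked bytes, if any).
def runTask[T](taskId: Long, killed: () => Boolean, body: () => T, cleanup: () => Long,
               report: (Long, TaskState, String) => Unit, logError: String => Unit): Unit = {
  report(taskId, TaskState.RUNNING, "")
  try {
    if (killed()) throw new TaskKilledException // killed before it even started
    val value =
      try body()
      finally {
        val leaked = cleanup()
        if (leaked > 0) logError(s"Managed memory leak detected; size = $leaked bytes, TID = $taskId")
      }
    if (killed()) throw new TaskKilledException // killed while running: fail it
    report(taskId, TaskState.FINISHED, value.toString)
  } catch {
    case _: TaskKilledException | _: InterruptedException if killed() =>
      report(taskId, TaskState.KILLED, "TaskKilled")
    case t: Throwable =>
      report(taskId, TaskState.FAILED, t.toString)
  }
}

val events = mutable.ArrayBuffer.empty[(Long, TaskState)]
val errors = mutable.ArrayBuffer.empty[String]
def report(id: Long, state: TaskState, data: String): Unit = events += ((id, state))

runTask(1L, () => false, () => 42, () => 0L, report, errors += _)                         // succeeds
runTask[Int](2L, () => false, () => sys.error("boom"), () => 128L, report, errors += _) // throws, leaks
runTask(3L, () => true, () => 7, () => 0L, report, errors += _)                           // already killed
```
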

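In the code above, Task.deserializeWithDependencies splits the serialized task into the files and jars it depends on plus the task bytes; updateDependencies then fetches any missing jars and adds them to the executor's class loader.
Next, the Task object itself is deserialized, and a TaskMemoryManager is attached to it to manage the task's on-heap and off-heap memory.
Task is an abstract class whose concrete subclasses are ShuffleMapTask and ResultTask; these two are the core task types Spark executes.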
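
The two subclasses differ in what they produce: a ShuffleMapTask writes map output for the next stage and returns a MapStatus describing where that output lives, while a ResultTask applies the action's function to its partition and returns the value to the driver. A heavily simplified sketch of that split; the class names are hypothetical, and the map task here keeps its hash-partitioned buckets in memory instead of writing shuffle files:

```scala
// Hypothetical, heavily simplified model of the two Task kinds (not Spark's real classes).
abstract class TaskSketch[T](val stageId: Int, val partitionId: Int) {
  def runTask(): T
}

// Map side of a shuffle: hash-partitions its records by key into one bucket per reducer.
final class ShuffleMapTaskSketch(stageId: Int, partitionId: Int,
                                 records: Seq[(String, Int)], numReducers: Int)
  extends TaskSketch[Map[Int, Seq[(String, Int)]]](stageId, partitionId) {
  def runTask(): Map[Int, Seq[(String, Int)]] =
    records.groupBy { case (key, _) => Math.floorMod(key.hashCode, numReducers) }
}

// Final stage: applies the action's function to the partition and returns the result.
final class ResultTaskSketch[U](stageId: Int, partitionId: Int,
                                records: Seq[(String, Int)], func: Seq[(String, Int)] => U)
  extends TaskSketch[U](stageId, partitionId) {
  def runTask(): U = func(records)
}

val data = Seq("a" -> 1, "b" -> 2, "a" -> 3)
val buckets = new ShuffleMapTaskSketch(0, 0, data, numReducers = 2).runTask()
val count = new ResultTaskSketch[Int](1, 0, data, _.size).runTask()
```
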
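When the task finishes, the following call reports the result to the driver:

  execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

Depending on its size, the result either travels inside this status update (DirectTaskResult), is stored in the BlockManager with only a reference sent (IndirectTaskResult), or is dropped if it exceeds maxResultSize.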
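
The three-way decision in the serializedResult block can be isolated as a pure function. A minimal sketch with hypothetical names (ResultRoute, routeResult); the sizes used below (a 1 GiB limit, a 128 MiB frame, a 200 KiB reserve) are illustrative values only, not a statement of any Spark version's defaults:

```scala
// Hypothetical model of how run() decides how to ship a serialized result back.
sealed trait ResultRoute
case object DirectToDriver extends ResultRoute  // DirectTaskResult inside the status update
case object ViaBlockManager extends ResultRoute // IndirectTaskResult; bytes kept in the BlockManager
case object Dropped extends ResultRoute         // IndirectTaskResult; bytes discarded

// Same order of checks as above: the maxResultSize limit first (0 means "no limit"),
// then whether the result still fits in one RPC frame minus the reserved overhead.
def routeResult(resultSize: Long, maxResultSize: Long, frameSize: Long, reservedBytes: Long): ResultRoute =
  if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
  else if (resultSize >= frameSize - reservedBytes) ViaBlockManager
  else DirectToDriver

val MiB = 1024L * 1024
val route = (size: Long) =>
  routeResult(size, maxResultSize = 1024 * MiB, frameSize = 128 * MiB, reservedBytes = 200 * 1024)
```

For example, a 10 KiB result goes straight to the driver, a 200 MiB result goes through the BlockManager, and a 2 GiB result is dropped.
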

The following code fetches the dependency files and jars, over schemes such as http, hdfs and file:

  /**
   * Download any missing dependencies if we receive a new set of files and JARs from the
   * SparkContext. Also adds any new JARs we fetched to the class loader.
   */
  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
    // Fetch newer files and jars, and load new jars into the class loader
    lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
    synchronized {
      // Fetch missing dependencies
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        // Fetch file with useCache mode, close cache for local mode.
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
          env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        currentFiles(name) = timestamp
      }
      for ((name, timestamp) <- newJars) {
        val localName = name.split("/").last
        val currentTimeStamp = currentJars.get(name)
          .orElse(currentJars.get(localName))
          .getOrElse(-1L)
        if (currentTimeStamp < timestamp) {
          logInfo("Fetching " + name + " with timestamp " + timestamp)
          // Fetch file with useCache mode, close cache for local mode.
          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
            env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
          currentJars(name) = timestamp
          // Add it to our class loader
          val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
          if (!urlClassLoader.getURLs().contains(url)) {
            logInfo("Adding " + url + " to class loader")
            urlClassLoader.addURL(url)
          }
        }
      }
    }
  }
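
The heart of updateDependencies is a timestamp comparison: a jar is fetched only if the driver advertises a newer timestamp than the one already recorded, and each local jar is added to the class loader once. A minimal sketch of that bookkeeping, assuming hypothetical fetch and addToClassLoader callbacks in place of Utils.fetchFile and urlClassLoader.addURL (files follow the same timestamp rule, minus the class loader step); the jar URL is made up:

```scala
import scala.collection.mutable

// Hypothetical tracker mirroring the currentJars bookkeeping in updateDependencies.
final class DependencyTracker(fetch: String => Unit, addToClassLoader: String => Unit) {
  val currentJars = mutable.HashMap.empty[String, Long]
  private val loaded = mutable.LinkedHashSet.empty[String]

  def updateJars(newJars: Map[String, Long]): Unit = synchronized {
    for ((name, timestamp) <- newJars) {
      val localName = name.split("/").last
      // Look up by full name first, then by local file name; -1 means "never fetched".
      val current = currentJars.get(name).orElse(currentJars.get(localName)).getOrElse(-1L)
      if (current < timestamp) {
        fetch(name)
        currentJars(name) = timestamp
        // Only add each local jar to the class loader once (like the getURLs().contains check).
        if (loaded.add(localName)) addToClassLoader(localName)
      }
    }
  }
}

val fetched = mutable.ArrayBuffer.empty[String]
val added = mutable.ArrayBuffer.empty[String]
val tracker = new DependencyTracker(fetched += _, added += _)
tracker.updateJars(Map("hdfs://nn/libs/app.jar" -> 100L))
tracker.updateJars(Map("hdfs://nn/libs/app.jar" -> 100L)) // same timestamp: skipped
tracker.updateJars(Map("hdfs://nn/libs/app.jar" -> 200L)) // newer: fetched again, not re-added
```
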

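That is the core of how the executor receives and runs tasks:

  1. Connect to the driver (over Akka or Netty) and register the executor
  2. Receive the tasks the driver dispatches
  3. Fetch the files and jars each task depends on
  4. Run the task
  5. Send the result back to the driver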