The Spark executor process
The entry class of the process that executes Spark tasks is CoarseGrainedExecutorBackend.
Inside the CoarseGrainedExecutorBackend process, a WorkerWatcher endpoint is also started to communicate with the Worker process on the local machine.
// This registers the CoarseGrainedExecutorBackend endpoint, which connects to the Spark driver and receives tasks to execute
env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
// This sets up communication with the Spark Worker process: when the Worker goes down, the WorkerWatcher notices the lost connection and terminates this process as well
workerUrl.foreach { url =>
env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
With the parameters passed in, the following entry class is constructed:
private[spark] class CoarseGrainedExecutorBackend(
override val rpcEnv: RpcEnv,
driverUrl: String,
executorId: String,
hostPort: String,
cores: Int,
userClassPath: Seq[URL],
env: SparkEnv)
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
// If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need
// to be changed so that we don't share the serializer instance across threads
private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()
override def onStart() {
logInfo("Connecting to driver: " + driverUrl)
// Connect to the driver
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
driver = Some(ref)
// Register this executor with the driver
ref.ask[RegisterExecutorResponse](
RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
}(ThreadUtils.sameThread).onComplete {
// This is a very fast action so we can use "ThreadUtils.sameThread"
case Success(msg) => Utils.tryLogNonFatalError {
// After registration completes, forward the response to self
Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse
}
case Failure(e) => {
logError(s"Cannot register with driver: $driverUrl", e)
System.exit(1)
}
}(ThreadUtils.sameThread)
}
def extractLogUrls: Map[String, String] = {
val prefix = "SPARK_LOG_URL_"
sys.env.filterKeys(_.startsWith(prefix))
.map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
}
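// Example with illustrative values: if SPARK_LOG_URL_STDOUT=http://host:8081/stdout is set
// in the environment, extractLogUrls returns Map("stdout" -> "http://host:8081/stdout").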
// The message receiver: handles messages dispatched to this endpoint
override def receive: PartialFunction[Any, Unit] = {
case RegisteredExecutor(hostname) =>
logInfo("Successfully registered with driver")
// Successfully registered with the driver, so create the Executor
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
System.exit(1)
case LaunchTask(data) =>
if (executor == null) {
logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
// Deserialize the task description
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
// Launch the task
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
executor.killTask(taskId, interruptThread)
}
case StopExecutor =>
logInfo("Driver commanded a shutdown")
// Cannot shutdown here because an ack may need to be sent back to the caller. So send
// a message to self to actually do the shutdown.
self.send(Shutdown)
case Shutdown =>
executor.stop()
stop()
rpcEnv.shutdown()
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
if (driver.exists(_.address == remoteAddress)) {
logError(s"Driver $remoteAddress disassociated! Shutting down.")
// The driver dropped the connection, so this process exits as well
System.exit(1)
} else {
logWarning(s"An unknown ($remoteAddress) driver disconnected.")
}
}
// Report task status updates back to the driver
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
val msg = StatusUpdate(executorId, taskId, state, data)
driver match {
case Some(driverRef) => driverRef.send(msg)
case None => logWarning(s"Drop $msg because has not yet connected to driver")
}
}
}
In the class above, the backend connects to the driver and registers the current executor process.
The connection is established through rpcEnv.asyncSetupEndpointRefByURI(driverUrl).
The receive handler covers LaunchTask, KillTask, and related messages, through which the backend accepts work from the driver; it also creates an Executor instance, which owns the threads that actually run tasks.
Note that the class mixes in ThreadSafeRpcEndpoint, i.e. the RpcEndpoint interface. rpcEnv has two implementations, NettyRpcEnv and AkkaRpcEnv, so Spark's underlying communication layer is either Netty or Akka. A minimal sketch of a custom endpoint follows.
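This sketch targets the same Spark 1.x private[spark] RPC API used above (RpcEnv.setupEndpoint, ThreadSafeRpcEndpoint); EchoEndpoint is a hypothetical example, not a class in Spark:
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
// A toy endpoint following the same pattern as CoarseGrainedExecutorBackend.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  // Fire-and-forget messages (RpcEndpointRef.send) arrive here.
  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => println(s"got: $msg")
  }
  // Request/response messages (RpcEndpointRef.ask) arrive here.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg) // echo the message back to the caller
  }
}
// Registration mirrors the executor setup shown earlier:
//   env.rpcEnv.setupEndpoint("Echo", new EchoEndpoint(env.rpcEnv))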
Below is the core code of the Executor class:
// Launch a task for execution
def launchTask(
context: ExecutorBackend,
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer): Unit = {
// Create the task runner
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
// Kill a running task
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill(interruptThread)
}
}
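launchTask simply wraps the task in a TaskRunner, records it in runningTasks, and hands it to a thread pool. For reference, the Spark 1.6-era Executor fields behind this look roughly like the following sketch (simplified; not the full class):
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.util.ThreadUtils
// Tasks currently running on this executor, keyed by taskId.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
// Cached daemon thread pool: each concurrently running task gets its own thread.
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")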
The core of task execution is TaskRunner.run():
override def run(): Unit = {
// Per-task memory management
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
startGCTime = computeTotalGcTime()
try {
// Split out the dependent files and jars, plus the serialized task bytes
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
// Download the dependencies and add any new jars to the classpath
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException
}
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
var threwException = true
val (value, accumUpdates) = try {
// Run the actual task
val res = task.run(
taskAttemptId = taskId,
attemptNumber = attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
} finally {
// Release any block locks still held by this task
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
// Free any memory allocated through the task memory manager
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
throw new SparkException(errMsg)
} else {
logError(errMsg)
}
}
if (releasedLocks.nonEmpty) {
val errMsg =
s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
releasedLocks.mkString("[", ", ", "]")
if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) {
throw new SparkException(errMsg)
} else {
logError(errMsg)
}
}
}
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
if (task.killed) {
throw new TaskKilledException
}
val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
// Serialize the task result
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
// Update the task metrics
for (m <- task.metrics) {
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
m.setExecutorDeserializeTime(
(taskStart - deserializeStartTime) + task.executorDeserializeTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
m.updateAccumulators()
}
// Wrap the result for shipping back to the driver
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
// The result is too large: drop it and send back only a reference with its size
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
// The result exceeds one Akka frame, so ship it through the BlockManager instead
val blockId = TaskResultBlockId(taskId)
// Store the serialized result as a block
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
serializedDirectResult
}
}
// Report the finished state and the result to the driver
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case ffe: FetchFailedException =>
val reason = ffe.toTaskEndReason
// Report the failure to the driver
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case _: TaskKilledException | _: InterruptedException if task.killed =>
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
task.metrics.map { m =>
m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.updateAccumulators()
m
}
}
val serializedTaskEndReason = {
try {
ser.serialize(new ExceptionFailure(t, metrics))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
ser.serialize(new ExceptionFailure(t, metrics, false))
}
}
// Report the failure to the driver
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
}
} finally {
runningTasks.remove(taskId)
}
}
}
In the code above, Task.deserializeWithDependencies first extracts the dependency lists, and updateDependencies reloads the required jars onto the classpath.
The Task object is then deserialized; it uses the TaskMemoryManager for both on-heap and off-heap memory management.
The Task interface is implemented by ShuffleMapTask and ResultTask, the two core task types Spark executes (see the sketch below).
When a task finishes, its result is reported back to the driver via:
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
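To make the Task hierarchy concrete, here is a simplified sketch of the Spark 1.6-era abstraction (the real class carries much more state; details omitted):
// Simplified sketch of the Task abstraction; the subclasses described below are real.
abstract class Task[T](val stageId: Int, val partitionId: Int) {
  // Invoked by Task.run() once the TaskContext has been set up.
  def runTask(context: TaskContext): T
}
// ShuffleMapTask extends Task[MapStatus]: it writes its partition's shuffle output
// and returns a MapStatus telling the driver where that output lives.
// ResultTask extends Task[U]: it applies the user function to its partition and
// returns the value, which is what ends up serialized back to the driver.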
Below, the dependency files are fetched over HTTP, HDFS, the local filesystem, and so on:
/**
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch updated files and jars, and load new jars into the class loader
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars) {
val localName = name.split("/").last
val currentTimeStamp = currentJars.get(name)
.orElse(currentJars.get(localName))
.getOrElse(-1L)
if (currentTimeStamp < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
if (!urlClassLoader.getURLs().contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
}
}
}
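Utils.fetchFile picks the transfer mechanism from the URI scheme of the dependency name. A simplified sketch of that dispatch (illustrative only, not the actual implementation):
// name is the dependency URI from the loops above.
Option(new java.net.URI(name).getScheme).getOrElse("file") match {
  case "http" | "https" | "ftp" => // download over HTTP(S)/FTP
  case "file"                   => // copy from a local or shared (e.g. NFS) path
  case _                        => // e.g. "hdfs": read through Hadoop's FileSystem API
}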
The above is the core task-receiving flow:
- Establish a connection to the driver over Akka or Netty
- Receive the tasks the driver dispatches
- Fetch the files and jars the task depends on
- Execute the task
- Send the result back to the driver