Link to the previous post: https://blog.youkuaiyun.com/cw1254332663/article/details/95327497
This is a summary of what I have learned; if anything is wrong, please point it out. Thanks!
Last time we saw that DriverEndpoint, an inner class of CoarseGrainedSchedulerBackend, serializes the TaskDescription and wraps it in the LaunchTask case class.
CoarseGrainedExecutorBackend has a method for receiving messages:
override def receive: PartialFunction[Any, Unit] = {
// An Executor is only created after this message is received
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
try {
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
} catch {
case NonFatal(e) =>
exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
}
case RegisterExecutorFailed(message) =>
exitExecutor(1, "Slave registration failed: " + message)
// This is the case that matches the LaunchTask message we sent
case LaunchTask(data) =>
// Check whether the executor exists
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
// Deserialize the TaskDescription
val taskDesc = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
// Call the executor's launch method to run the task
executor.launchTask(this, taskDesc)
}
// Kill a task
case KillTask(taskId, _, interruptThread, reason) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
executor.killTask(taskId, interruptThread, reason)
}
case StopExecutor =>
stopping.set(true)
logInfo("Driver commanded a shutdown")
self.send(Shutdown)
case Shutdown =>
stopping.set(true)
new Thread("CoarseGrainedExecutorBackend-stop-executor") {
override def run(): Unit = {
executor.stop()
}
}.start()
}
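Before going further, here is a tiny self-contained sketch (invented names modeled on the messages above; this is not Spark code) of how a PartialFunction-based receive handler dispatches case-class messages:
// A toy sketch of a PartialFunction-based message handler.
// It only mimics the dispatch style of receive above; it is NOT Spark code.
object ReceiveSketch {
  sealed trait Message
  case object RegisteredExecutor extends Message
  case class LaunchTask(data: Array[Byte]) extends Message
  case class KillTask(taskId: Long) extends Message

  val receive: PartialFunction[Message, Unit] = {
    case RegisteredExecutor =>
      println("registered: this is where the real code creates the Executor")
    case LaunchTask(data) =>
      println(s"launching a task encoded in ${data.length} bytes")
    case KillTask(taskId) =>
      println(s"killing task $taskId")
  }

  def main(args: Array[String]): Unit = {
    Seq(RegisteredExecutor, LaunchTask(Array[Byte](1, 2, 3)), KillTask(7L)).foreach(receive)
  }
}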
Let's take a look at the executor.launchTask method:
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
// Wrap the ExecutorBackend and TaskDescription in a Runnable
val tr = new TaskRunner(context, taskDescription)
// runningTasks keeps track of the tasks that are currently running
runningTasks.put(taskDescription.taskId, tr)
// Submit the Runnable to the thread pool for execution
threadPool.execute(tr)
}
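As a stripped-down illustration of the same pattern, you can think of launchTask as: wrap the work in a Runnable, remember it in a concurrent map, and hand it to a thread pool. A minimal sketch with invented names (not Spark's actual classes):
import java.util.concurrent.{ConcurrentHashMap, Executors}

// Minimal sketch of the launchTask pattern: register the Runnable, then execute it on a pool.
object LaunchTaskSketch {
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()
  private val threadPool = Executors.newCachedThreadPool()

  class TaskRunnerSketch(taskId: Long) extends Runnable {
    override def run(): Unit = {
      try println(s"running task $taskId on ${Thread.currentThread().getName}")
      finally runningTasks.remove(taskId)   // Spark's TaskRunner also deregisters itself when done
    }
  }

  def launchTask(taskId: Long): Unit = {
    val tr = new TaskRunnerSketch(taskId)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }

  def main(args: Array[String]): Unit = {
    (1L to 3L).foreach(launchTask)
    threadPool.shutdown()
  }
}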
Now let's look at the logic of TaskRunner's run method:
override def run(): Unit = {
/*
A bunch of bookkeeping and information-gathering code here
......
*/
// Get the closure serializer
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
// Report the task's RUNNING state to the cluster scheduler
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
try {
Executor.taskDeserializationProps.set(taskDescription.properties)
// If we received a new set of files and JARs from the SparkContext, download any missing dependencies
// and load the new JARs through the class loader.
updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
// Deserialize our task
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, quit now. Otherwise, continue executing the task.
val killReason = reasonIfKilled
if (killReason.isDefined) {
throw new TaskKilledException(killReason.get)
}
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
// Run the actual task and measure its run time
taskStart = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
var threwException = true
val value = try {
// This is the key call: run the task
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
}
}
......
}
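The key step above is ser.deserialize[Task[Any]](...): the bytes the driver serialized are turned back into a Task object on the executor. Here is a rough stand-alone sketch of that round trip using plain Java serialization (FakeTask and both helpers are invented for illustration; Spark actually uses its configured closure serializer):
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer

// Rough sketch of the serialize-on-driver / deserialize-on-executor round trip.
object DeserializeSketch {
  case class FakeTask(stageId: Int, partitionId: Int) extends Serializable

  def serialize(task: FakeTask): ByteBuffer = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(task)
    out.close()
    ByteBuffer.wrap(bytes.toByteArray)
  }

  def deserialize(buf: ByteBuffer): FakeTask = {
    val in = new ObjectInputStream(new ByteArrayInputStream(buf.array()))
    in.readObject().asInstanceOf[FakeTask]
  }

  def main(args: Array[String]): Unit = {
    val task = deserialize(serialize(FakeTask(stageId = 0, partitionId = 3)))
    println(s"got $task")   // the executor would now call task.run(...)
  }
}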
Let's continue following the call into task.run():
final def run(
taskAttemptId: Long,
attemptNumber: Int,
metricsSystem: MetricsSystem): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
// Some initialization happens here (a TaskContext is created, etc.)
......
try {
// runTask is called here
runTask(context)
}
......
}
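So run is a template method: it does the shared setup and then delegates to runTask, which ShuffleMapTask and ResultTask each override. A minimal sketch of that structure (abbreviated, invented names):
// Simplified sketch of the template-method structure behind Task.run/runTask.
object TaskTemplateSketch {
  abstract class SketchTask[T] {
    // run() does the shared bookkeeping, then delegates to the subclass's runTask()
    final def run(taskAttemptId: Long): T = {
      println(s"setting up context for attempt $taskAttemptId")
      runTask()
    }
    protected def runTask(): T
  }

  class SketchShuffleMapTask extends SketchTask[String] {
    override protected def runTask(): String = "MapStatus(...)"   // ShuffleMapTask returns a MapStatus
  }

  class SketchResultTask extends SketchTask[Int] {
    override protected def runTask(): Int = 42                    // ResultTask returns the user function's result
  }

  def main(args: Array[String]): Unit = {
    println(new SketchShuffleMapTask().run(0L))
    println(new SketchResultTask().run(1L))
  }
}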
Next, ShuffleMapTask.runTask:
override def runTask(context: TaskContext): MapStatus = {
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
// Deserialize the RDD from a broadcast variable; taskBinary is a Broadcast[Array[Byte]]
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
var writer: ShuffleWriter[Any, Any] = null
try {
// Get the ShuffleManager.
// There are three kinds of ShuffleManager: Hash, Sort (the default), and Tungsten-sort; a later post will cover them in detail.
val manager = SparkEnv.get.shuffleManager
// Get the writer
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
// This can be broken into two steps:
// 1. computing the RDD
// 2. storing the computed results
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}
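Before moving on, note that taskBinary is a Broadcast[Array[Byte]]: the driver broadcasts the serialized (rdd, dep) pair once per stage, and every task fetches the same bytes instead of having them shipped individually. Broadcast variables behave the same way at the user level; a small public-API example (local mode, illustrative only):
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("broadcast-example").setMaster("local[2]"))

    // The driver serializes the value once; each executor fetches and caches it on first use.
    val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))

    val resolved = sc.parallelize(Seq("a", "b", "a"))
      .map(key => lookup.value.getOrElse(key, 0))   // tasks read the broadcast value, not a per-task copy

    println(resolved.collect().toSeq)   // Seq(1, 2, 1)
    sc.stop()
  }
}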
Let's look at the first step, computing the RDD:
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
// Check whether this RDD has been marked for persistence (storage level is not NONE)
if (storageLevel != StorageLevel.NONE) {
// Get or compute a partition of the RDD
getOrCompute(split, context)
} else {
// Compute the partition or read it from the checkpoint
computeOrReadCheckpoint(split, context)
}
}
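At the user level this branch is controlled by persist/cache: once a storage level other than NONE is set, iterator goes through getOrCompute and later jobs can reuse the cached partitions. A small example with the public API (stand-alone app in local mode):
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object PersistBranchExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("persist-example").setMaster("local[*]"))
    val rdd = sc.parallelize(1 to 100, numSlices = 4).map(_ * 2)

    rdd.persist(StorageLevel.MEMORY_ONLY)   // storageLevel != NONE => getOrCompute path
    rdd.count()                             // first action computes and caches the partitions
    rdd.count()                             // second action reads the cached blocks

    sc.stop()
  }
}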
The getOrCompute method:
private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
// Build the block id for this RDD partition's cache block
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// If the specified block exists, retrieve it;
// otherwise call the provided `makeIterator` to compute the block, persist it, and return its values.
// If the block exists, a BlockResult is returned; if not, an iterator is returned.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
case Left(blockResult) =>
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(blockResult.bytes)
new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
override def next(): T = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
} else {
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
}
case Right(iter) =>
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
}
}
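Going by the comment above, getOrElseUpdate behaves like a get-or-else-update cache: Left carries an already-stored block, Right carries a freshly computed iterator. A toy sketch of that contract (all names invented; the real method also handles storage levels, memory pressure, and so on):
import scala.collection.mutable

object GetOrElseUpdateSketch {
  private val cache = mutable.Map[String, Vector[Int]]()

  def getOrElseUpdate(blockId: String, makeIterator: () => Iterator[Int]): Either[Vector[Int], Iterator[Int]] =
    cache.get(blockId) match {
      case Some(values) => Left(values)               // block already cached
      case None =>
        val computed = makeIterator().toVector
        cache.put(blockId, computed)                  // remember it for the next caller
        Right(computed.iterator)
    }

  def main(args: Array[String]): Unit = {
    println(getOrElseUpdate("rdd_0_0", () => Iterator(1, 2, 3)))   // Right(...): computed now
    println(getOrElseUpdate("rdd_0_0", () => Iterator(1, 2, 3)))   // Left(...): served from the cache
  }
}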
And the computeOrReadCheckpoint method:
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
if (isCheckpointedAndMaterialized) {
// Call the parent RDD's iterator method (after checkpointing, the first parent is the CheckpointRDD)
firstParent[T].iterator(split, context)
} else {
// Run this RDD's own compute logic
compute(split, context)
}
}
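At the user level the checkpoint branch is only reachable after checkpoint() has been materialized by an action. A small public-API example (local mode; the checkpoint directory is just a placeholder):
import org.apache.spark.{SparkConf, SparkContext}

object CheckpointExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("checkpoint-example").setMaster("local[*]"))
    sc.setCheckpointDir("/tmp/spark-checkpoints")   // placeholder directory

    val rdd = sc.parallelize(1 to 10).map(_ + 1)
    rdd.checkpoint()                                // mark the RDD for checkpointing
    rdd.count()                                     // the action materializes the checkpoint
    println(rdd.isCheckpointed)                     // true: later reads go through the checkpoint data

    sc.stop()
  }
}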
Next, the compute method of a MapPartitionsRDD (the RDD produced by non-shuffle operators):
override def compute(split: Partition, context: TaskContext): Iterator[U] =
// The parent RDD's iterator method is called here
f(context, split.index, firstParent[T].iterator(split, context))
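This one line is where the narrow-dependency chain unwinds: f receives the parent partition's iterator and produces the new one. The closest public-API analogue is mapPartitionsWithIndex, whose user function has almost the same shape (minus the TaskContext):
import org.apache.spark.{SparkConf, SparkContext}

object MapPartitionsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("map-partitions-example").setMaster("local[2]"))

    val doubled = sc.parallelize(1 to 8, numSlices = 2)
      .mapPartitionsWithIndex { (index, iter) =>
        // iter is the parent partition's iterator, exactly what firstParent.iterator(...) supplies
        iter.map(x => s"partition $index -> ${x * 2}")
      }

    doubled.collect().foreach(println)
    sc.stop()
  }
}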
And the compute method of an RDD produced by a shuffle operator (the snippet here is from ShuffledRowRDD, the Spark SQL variant):
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
// Get the shuffle reader
val reader =
SparkEnv.get.shuffleManager.getReader(
dependency.shuffleHandle,
shuffledRowPartition.startPreShufflePartitionIndex,
shuffledRowPartition.endPreShufflePartitionIndex,
context)
// Call read(), which looks up where the previous stage wrote its intermediate shuffle files (via the BlockManager) and fetches the data
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
}
Now let's look at the second step, writer.write():
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
// ExternalSorter maintains two in-memory structures. For aggregating operators such as reduceByKey,
// records first go into a map; on every insert the estimated size is checked against a threshold,
// and once the threshold is reached the data is spilled through a buffer and the structure is cleared.
// When the buffer (32K by default) fills up, its contents are written to a file on disk.
// Ordinary shuffle operators such as join use an array structure instead.
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the ExternalSorter,
// because we don't care whether the keys are sorted within each partition;
// if the running operation is sortByKey, the sorting will happen on the reduce side.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
// Write the records into the in-memory structure (map or array); every insert checks whether a spill is needed
sorter.insertAll(records)
// With SortShuffleManager, each map task outputs a single data file plus an index file
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
// Create a temporary file in the output file's directory
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
// Write all the data added to the ExternalSorter out to disk
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
// Write an index file recording the offset of each block, plus a final offset at the end of the output file.
// getBlockData uses this index to locate each block when the data is fetched.
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}
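Which branch is taken depends on the operator that created the ShuffleDependency. A small public-API example (local mode) contrasting the cases described in the comments above:
import org.apache.spark.{SparkConf, SparkContext}

object MapSideCombineExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("map-side-combine-example").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1), ("b", 2)), numSlices = 2)

    pairs.reduceByKey(_ + _).collect().foreach(println)   // map-side combine: partial sums before the shuffle
    pairs.groupByKey().collect().foreach(println)         // no map-side combine: raw records are shuffled
    pairs.sortByKey().collect().foreach(println)          // no map-side combine: sorting happens on the reduce side

    sc.stop()
  }
}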
There is quite a lot here; I will keep filling in the details in future posts...