In the previous post we analyzed the shuffle write path in the map stage; in this one we move on to the read path. The shuffle read takes place when the compute method of an RDD behind a wide dependency (such as ShuffledRDD) is invoked, so let's start with the source of ShuffledRDD.compute():
// ShuffledRDD.scala
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
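Before digging into the reader itself, it may help to see when this compute method actually runs. The following is a minimal, hypothetical sketch of my own (it assumes a local Spark setup and is not taken from Spark's source): reduceByKey on a pair RDD builds a ShuffledRDD, and computing any of its partitions triggers the read path shown above.
// ShuffleReadDemo.scala (illustrative sketch only)
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    // reduceByKey builds a ShuffledRDD with 3 partitions; when a reduce task
    // computes its partition, compute() calls
    // getReader(handle, split.index, split.index + 1, context),
    // i.e. every reduce task reads exactly one shuffle partition.
    val reduced = pairs.reduceByKey(_ + _, 3)
    println(reduced.getClass.getSimpleName) // prints "ShuffledRDD"
    println(reduced.collect().toSeq)
    sc.stop()
  }
}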
First a reader is obtained through the ShuffleManager, and the data is then read through that reader.
// SortShuffleManager.scala
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
Below is the actual data-reading logic: the BlockStoreShuffleReader.read() method. The reader covers the half-open partition range [startPartition, endPartition); as compute() above shows, each reduce task reads exactly one partition.
To summarize, the main steps of this method are:
- Obtain a wrapping iterator, ShuffleBlockFetcherIterator, whose elements are pairs of a blockId and an input stream for that block; this class is clearly the key to reading data in the reduce stage
- Turn the raw input streams into a deserialized record iterator
- Wrap the iterator so that read metrics are collected; this chain of wrappers is much like the decorator pattern applied to streams in Java
- Wrap the iterator so that it can respond to interruption. Every record read checks whether the task has been killed; this is done so that a kill request, for example a kill message sent from the driver, is honored as promptly as possible.
- Use the aggregator, if one is defined, to combine the results. The AppendOnlyMap data structure is used again here; it also appeared in the shuffle write stage. Internally it is an open-addressing hash table laid out in a single flat array, probing other slots on a collision (see the sketch after this list).
- Finally, sort the results.
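To make the aggregation step more concrete, here is a minimal sketch of the flat-array, open-addressing idea behind AppendOnlyMap. All names in it are mine, and it is simplified: the real class additionally grows its table and uses a different probe sequence, while this toy version uses plain linear probing and never resizes (so it holds at most `capacity` distinct keys).
// TinyAppendOnlyMap.scala (illustrative sketch; this is not Spark's AppendOnlyMap)
// Keys and values share one flat array: key at index 2*i, value at index 2*i+1.
class TinyAppendOnlyMap[K, V](capacity: Int = 64) {
  private val data = new Array[Any](2 * capacity)
  private val mask = capacity - 1 // capacity must be a power of 2

  /** Insert or update: set the value for `key` to updateFunc(existing value, if any). */
  def changeValue(key: K, updateFunc: Option[V] => V): V = {
    var pos = key.hashCode() & mask
    while (true) {
      val curKey = data(2 * pos)
      if (curKey == null) {
        // Empty slot: append the new key here
        val newValue = updateFunc(None)
        data(2 * pos) = key
        data(2 * pos + 1) = newValue
        return newValue
      } else if (curKey == key) {
        // Existing key: update its value in place
        val newValue = updateFunc(Some(data(2 * pos + 1).asInstanceOf[V]))
        data(2 * pos + 1) = newValue
        return newValue
      } else {
        // Collision: probe the next slot
        pos = (pos + 1) & mask
      }
    }
    throw new IllegalStateException("unreachable")
  }
}

// Example: aggregating (word, 1) records on the reduce side
// val m = new TinyAppendOnlyMap[String, Int]()
// Seq("a", "b", "a").foreach(w => m.changeValue(w, prev => prev.getOrElse(0) + 1))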
// BlockStoreShuffleReader.scala
override def read(): Iterator[Product2[K, C]] = {
  // Obtain a wrapping iterator whose elements are (blockId, input stream for that block) pairs
  val wrappedStreams = new ShuffleBlockFetcherIterator( // 1
    context,
    // This is the BlockTransferService unless the external shuffle service is enabled
    blockManager.shuffleClient,
    blockManager,
    // Ask the mapOutputTracker for the physical locations of the data blocks for each partition
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    // Read a few configuration parameters
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each stream
  // Turn the raw input streams into a deserialized record iterator
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the context task metrics for each record read.
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  // Wrap the iterator so that read metrics are collected; this chain of wrappers
  // resembles the decorator pattern applied to streams in Java
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  // Every record read checks whether the task has been killed, so a kill request
  // (e.g. a kill message sent from the driver) is honored as promptly as possible
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    // Use the aggregator to combine the results
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // t