Recently I ran into a problem with Spark Structured Streaming reading from Kafka: the very first batch pulled in far too many records and crashed the job. Searching around turned up nothing useful, so the only option left was to dig into the source of the Kafka connector jar and see whether the maximum number of records read in a single batch could be capped. In the end I got it working by modifying the source and recompiling.
When spark-sql-kafka runs a streaming query, each batch first takes the topic-partition offsets recorded at the end of the previous batch, then asks a KafkaConsumer for the latest offsets, reads the records in between for every partition, wraps them into an RDD[Row], and hands the result to the rest of the processing.
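As a rough sketch of that flow (the names below are made up for illustration and are not the connector's actual classes), a batch boils down to turning two offset maps into one read range per partition:

import org.apache.kafka.common.TopicPartition

// Illustrative only: one read range per partition for a single batch.
case class ReadRange(tp: TopicPartition, fromOffset: Long, untilOffset: Long)

// fromOffsets: offsets recorded at the end of the previous batch
// untilOffsets: latest offsets just fetched from the brokers
def planBatch(fromOffsets: Map[TopicPartition, Long],
              untilOffsets: Map[TopicPartition, Long]): Seq[ReadRange] =
  untilOffsets.toSeq.map { case (tp, until) =>
    ReadRange(tp, fromOffsets.getOrElse(tp, until), until)
  }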
The patch touches two places. First, when a batch is handed out, the end offset of every partition is cached. Second, when the latest offsets are fetched for the next batch, each one is compared with the cached end offset; if the gap is larger than the number of records a single partition can handle in one batch (call it maxnumber_per_partition_per_batch), only the previous end offset plus that cap is returned. The most data Spark processes in one batch is therefore the number of Kafka partitions multiplied by the per-partition cap.
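For instance, with the cap hard-coded to 10000 as in the code below: a partition whose previous end offset was 100 and whose latest offset is 500000 is only advanced to 10100 in the next batch, and with 5 partitions a single batch reads at most 5 × 10000 = 50000 records.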
The modified files are:
KafkaSource.scala
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    // Make sure initialPartitionOffsets is initialized
    initialPartitionOffsets

    logInfo(s"GetBatch called with start = $start, end = $end")
    val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end)
    // On recovery, getBatch will get called before getOffset
    if (currentPartitionOffsets.isEmpty) {
      currentPartitionOffsets = Some(untilPartitionOffsets)
    }
    // Added: cache this batch's end offsets so that the next fetchLatestOffsets
    // call can cap how far each partition is allowed to advance.
    OffsetCache.pushOffsetInfo(end.json())
    if (start.isDefined && start.get == end) {
      return sqlContext.internalCreateDataFrame(
        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
    }
    val fromPartitionOffsets = start match {
      case Some(prevBatchEndOffset) =>
        KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset)
      case None =>
        initialPartitionOffsets
    }
    ...
  }
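The argument passed to OffsetCache.pushOffsetInfo is the JSON form of the batch's end offset, i.e. a map of topic to partition to offset. The topic name and numbers below are only an example of its shape:

{"my-topic":{"0":12345,"1":12400,"2":12398}}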
KafkaOffsetReader.scala
  def fetchLatestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
    withRetriesWithoutInterrupt {
      // Poll to get the latest assigned partitions
      consumer.poll(0)
      val partitions = consumer.assignment()
      consumer.pause(partitions)
      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
      consumer.seekToEnd(partitions)

      // Added: per-partition cap on how many new records a single batch may see.
      val batchSize = 10000
      val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap.map(p => {
        var offset: Long = p._2
        // If the previous batch's end offset for this topic-partition is cached,
        // do not let the reported latest offset run more than batchSize ahead of it.
        if (OffsetCache.topic_patition_offset_map.contains(p._1.topic())) {
          val partitionOffsetMap = OffsetCache.topic_patition_offset_map(p._1.topic())
          if (partitionOffsetMap.contains(p._1.partition())) {
            val preOffset = partitionOffsetMap(p._1.partition())
            if (offset - preOffset > batchSize) {
              offset = preOffset + batchSize
            }
          }
        }
        (p._1, offset)
      })

      logDebug(s"Got latest offsets for partition : $partitionOffsets")
      partitionOffsets
    }
  }
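The capping rule itself is easy to check in isolation. Here is a minimal standalone version of the same logic, using the hard-coded batchSize of 10000 from above and made-up offsets:

val batchSize = 10000L

// Never report an end offset more than batchSize ahead of the previous batch's end offset.
def capOffset(latestOffset: Long, previousEnd: Long): Long =
  if (latestOffset - previousEnd > batchSize) previousEnd + batchSize else latestOffset

capOffset(latestOffset = 500000L, previousEnd = 100L) // 10100
capOffset(latestOffset = 5000L, previousEnd = 100L)   // 5000, under the cap so unchanged

Because only the reported latest offset is capped, no data is skipped: the records beyond the cap stay in Kafka and are picked up by the following batches.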
The newly added file, OffsetCache.scala:
package org.apache.spark.sql.kafka010

import com.google.gson.Gson

import scala.collection.JavaConversions._
import scala.collection.mutable

// Caches the end offset of every topic-partition from the previous batch so that
// KafkaOffsetReader.fetchLatestOffsets can cap how far each partition advances.
object OffsetCache {

  val gson = new Gson

  // topic -> (partition -> end offset of the previous batch)
  val topic_patition_offset_map = mutable.Map[String, mutable.Map[Int, Long]]()

  // `json` is the offset JSON passed in from KafkaSource.getBatch,
  // i.e. a map of topic -> partition -> offset.
  def pushOffsetInfo(json: String): Unit = {
    val topicPartOffsetMap =
      gson.fromJson(json, classOf[java.util.Map[String, java.util.Map[String, Double]]])
    topicPartOffsetMap.foreach(topic_map => {
      val partition_offset_map: mutable.Map[Int, Long] =
        topic_patition_offset_map.getOrElse(topic_map._1, mutable.Map[Int, Long]())
      topic_map._2.foreach(partition_offset => {
        partition_offset_map.put(partition_offset._1.toInt, partition_offset._2.toLong)
      })
      topic_patition_offset_map.put(topic_map._1, partition_offset_map)
    })
    println(s"topic_patition_offset_map $topic_patition_offset_map")
  }
}
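A quick way to see what the cache ends up holding is to call pushOffsetInfo with the same kind of JSON that getBatch passes in (topic name and offsets below are made up):

OffsetCache.pushOffsetInfo("""{"my-topic":{"0":12345,"1":12400}}""")
// topic_patition_offset_map now contains Map("my-topic" -> Map(0 -> 12345, 1 -> 12400))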