updateStateByKey

本文深入解析Spark Streaming中updateStateByKey与mapWithState两种状态管理方式的底层实现原理,对比其性能差异,提供完整示例代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

updateStateByKey操作允许您在使用新的信息持续更新时保持任意状态。
1、定义状态 - 状态可以是任意数据类型。
2、定义状态更新功能 - 使用函数指定如何使用上一个状态更新状态,并从输入流中指定新值。

如何使用该函数,spark文档写的很模糊,网上资料也不够详尽,自己翻阅源码总结一下,并给一个完整的例子
updateStateBykey函数有6种重载函数:
1、只传入一个更新函数,最简单的一种。
更新函数两个参数Seq[V], Option[S],前者是每个key新增的值的集合,后者是当前保存的状态,

def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S]
  ): DStream[(K, S)] = ssc.withScope {
  updateStateByKey(updateFunc, defaultPartitioner())
}

例如,对于wordcount,我们可以这样定义更新函数:

(values:Seq[Int],state:Option[Int])=>{
  //创建一个变量,用于记录单词出现次数
  var newValue=state.getOrElse(0) //getOrElse相当于if....else.....
  for(value <- values){
    newValue +=value //将单词出现次数累计相加
  }
  Option(newValue)
}
def updateFunction(currValues:Seq[Int],preValue:Option[Int]): Option[Int] = {
       val currValueSum = currValues.sum
        //上面的Int类型都可以用对象类型替换
        Some(currValueSum + preValue.getOrElse(0)) //当前值的和加上历史值
    }
    kafkaStream.map(r => (r._2,1)).updateStateByKey(updateFunction _)

这里的updateFunction方法就是需要我们自己去实现的状态跟新的逻辑,currValues就是当前批次的所有值,preValue是历史维护的状态,updateStateByKey返回的是包含历史所有状态信息的DStream,下面我们来看底层是怎么实现状态的管理的,通过跟踪源码看到最核心的实现方法:

  private [this] def computeUsingPreviousRDD(
      batchTime: Time,
      parentRDD: RDD[(K, V)],
      prevStateRDD: RDD[(K, S)]) = {
    // Define the function for the mapPartition operation on cogrouped RDD;
    // first map the cogrouped tuple to tuples of required type,
    // and then apply the update function
    val updateFuncLocal = updateFunc
    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
      val i = iterator.map { t =>
        val itr = t._2._2.iterator
        val headOption = if (itr.hasNext) Some(itr.next()) else None
        (t._1, t._2._1.toSeq, headOption)
      }
      updateFuncLocal(batchTime, i)
    }
    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
    Some(stateRDD)
  }

可以看到是将parentRDD和preStateRDD进行co-group,然后将finalFunc方法作用于每个Partition,看到finalFunc方法的实现里面(t._1, t._2._1.toSeq, headOption)这样的形式,(key,currValues,preValue)这不就是和我们需要自己实现的updateFun类似的结构吗,是的没错,我们的方法已经被包装了一次:

def updateStateByKey[S: ClassTag](
      updateFunc: (Seq[V], Option[S]) => Option[S],
      partitioner: Partitioner
    ): DStream[(K, S)] = ssc.withScope {
    val cleanedUpdateF = sparkContext.clean(updateFunc)
    val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
      iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
    }
    updateStateByKey(newUpdateFunc, partitioner, true)
  }

可以知道每次调用updateStateByKey都会将旧的状态RDD和当前batch的RDD进行co-group来得到一个新的状态RDD,即使真正需要跟新的数据只有1条也需要将两个RDD进行cogroup,所有的数据都会被计算一遍,而且随着状态的不断增加,运行速度会越来越慢。

为了解决这一问题,mapWithState应运而生。

mapWithState

   val initialRDD = ssc.sparkContext.parallelize(List[(String, Int)]())
    //自定义mappingFunction,累加单词出现的次数并更新状态
    val mappingFunc = (word: String, count: Option[Int], state: State[Int]) => {
      val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum)
      output
    }
    //调用mapWithState进行管理流数据的状态
    kafkaStream.map(r => (r._2,1)).mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)).print()

这里的initialRDD就是初始化状态,updateStateByKey也有对应的API。这里的mappingFun也是需要我们自己实现的状态跟新逻辑,调用state.update()就是对状态的跟新,output就是通过mapWithState后返回的DStream中的数据形式。注意这里不是直接传入的mappingFunc函数,而是一个StateSpec 的对象,其实也是对函数的一个包装而已。接下来我们跟踪源码看看是怎么实现状态的管理的,会创建一个MapWithStateDStreamImpl实例:

def mapWithState[StateType: ClassTag, MappedType: ClassTag](
      spec: StateSpec[K, V, StateType, MappedType]
    ): MapWithStateDStream[K, V, StateType, MappedType] = {
    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
      self,
      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
    )
  }

当然是要看看其compute方法是怎么实现的:

 private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
 
  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

compute方法又把处理逻辑给了internalStream:InternalMapWithStateDStream,继续看InternalMapWithStateDStream的compute方法主要处理逻辑:

override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
    // Get the previous state or create a new empty state RDD
    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
      case Some(rdd) =>
        if (rdd.partitioner != Some(partitioner)) {
          // If the RDD is not partitioned the right way, let us repartition it using the
          // partition index as the key. This is to ensure that state RDD is always partitioned
          // before creating another state RDD using it
          MapWithStateRDD.createFromRDD[K, V, S, E](
            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
        } else {
          rdd
        }
      case None =>
        MapWithStateRDD.createFromPairRDD[K, V, S, E](
          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
          partitioner,
          validTime
        )
    }

    // Compute the new state RDD with previous state RDD and partitioned data RDD
    // Even if there is no data RDD, use an empty one to create a new state RDD
    val dataRDD = parent.getOrCompute(validTime).getOrElse {
      context.sparkContext.emptyRDD[(K, V)]
    }
    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
      (validTime - interval).milliseconds
    }
    Some(new MapWithStateRDD(
      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
  }

先后获取prevStateRDD和parentRDD,并且保证使用的是同样的partitioner,接着以两个rdd为参数、自定义的mappingFunction函数、以及key的超时时间等为参数又创建了MapWithStateRDD,该RDD继承了RDD[MapWithStateRDDRecord[K, S, E]],MapWithStateRDD中的数据都是MapWithStateRDDRecord对象,每个分区对应一个对象来保存状态(这就是为什么两个RDD需要用同一个Partitioner),看看MapWithStateRDD的compute方法:

 override def compute(
      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
    val prevStateRDDIterator = prevStateRDD.iterator(
      stateRDDPartition.previousSessionRDDPartition, context)
    val dataIterator = partitionedDataRDD.iterator(
      stateRDDPartition.partitionedDataRDDPartition, context)

    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
      prevRecord,
      dataIterator,
      mappingFunction,
      batchTime,
      timeoutThresholdTime,
      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
    )
    Iterator(newRecord)
  }

拿到prevStateRDD和parentRDD对应分区的迭代器,接着获取了prevStateRDD的一条数据,这个分区也只有一条MapWithStateRDDRecord类型的数据,维护了对应分区所有数据状态,接着调用了最核心的方法来跟新状态,最后返回了只包含一条数据的迭代器,我们来看看是怎么这个核心的计算逻辑:

 def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

    val mappedData = new ArrayBuffer[E]
    val wrappedState = new StateImpl[S]()

    // Call the mapping function on each record in the data iterator, and accordingly
    // update the states touched, and collect the data returned by the mapping function
    dataIterator.foreach { case (key, value) =>
      wrappedState.wrap(newStateMap.get(key))
      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
      if (wrappedState.isRemoved) {
        newStateMap.remove(key)
      } else if (wrappedState.isUpdated
          || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
      }
      mappedData ++= returned
    }

    // Get the timed out state records, call the mapping function on each and collect the
    // data returned
    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
        wrappedState.wrapTimingOutState(state)
        val returned = mappingFunction(batchTime, key, None, wrappedState)
        mappedData ++= returned
        newStateMap.remove(key)
      }
    }

    MapWithStateRDDRecord(newStateMap, mappedData)
  }

先copy了原来的状态,接着定义了两个变量,mappedData是最终要返回的结果,wrappedState可以看成是对state的包装,添加了一些额外的方法。

接着遍历当前批次的数据,从状态中取出key对应的原来的state,并根据自定义的函数来对state进行跟新,这里涉及到state的remove&update&timeout来对newStateMap进行跟新操作,并将有跟新的状态加入到了mappedData中。

若有设置超时时间,则还会对超时了的key进行移除,也会加入到mappedData中,最终通过新的状态对象newStateMap和需返回的mappedData数组构建了MapWithStateRDDRecord对象来返回。

而在前面提到的MapWithStateDStreamImpl实例的compute方法中:

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

调用的就是这个mappedData数据。

我们发现返回的都是有update的数据,若要获取所有的状态在mapWithState之后调用stateSnapshots即可。若要清除某个key的状态,可在自定义的方法中调用state.remove()。

总结
updateStateByKey底层是将preSateRDD和parentRDD进行co-group,然后对所有数据都将经过自定义的mapFun函数进行一次计算,即使当前batch只有一条数据也会进行这么复杂的计算,大大的降低了性能,并且计算时间会随着维护的状态的增加而增加。
mapWithstate底层是创建了一个MapWithStateRDD,存的数据是MapWithStateRDDRecord对象,一个Partition对应一个MapWithStateRDDRecord对象,该对象记录了对应Partition所有的状态,每次只会对当前batch有的数据进行跟新,而不会像updateStateByKey一样对所有数据计算。

2、传入更新函数和分区数

def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    numPartitions: Int
  ): DStream[(K, S)] = ssc.withScope {
  updateStateByKey(updateFunc, defaultPartitioner(numPartitions))
}

3、传入更新函数和自定义分区

def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    partitioner: Partitioner
  ): DStream[(K, S)] = ssc.withScope {
  val cleanedUpdateF = sparkContext.clean(updateFunc)
  val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
    iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
  }
  updateStateByKey(newUpdateFunc, partitioner, true)
}

4、传入完整的状态更新函数
前面的函数传入的都是不完整的更新函数,只是针对一个key的,他们在执行的时候也会生成一个完整的状态更新函数。
Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)] 入参是一个迭代器,参数1是key,参数2是这个key在这个batch中更新的值的集合,参数3是当前状态,最终得到key–>newvalue

def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean
  ): DStream[(K, S)] = ssc.withScope {
   new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}

例如,对于wordcount:

val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
  iterator.flatMap(t => function1(t._2, t._3).map(s => (t._1, s)))
}

5、加入初始状态

 initialRDD: RDD[(K, S)] 初始状态集合
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    partitioner: Partitioner,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)] = ssc.withScope {
  val cleanedUpdateF = sparkContext.clean(updateFunc)
  val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
    iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
  }
  updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
}

6、是否记得当前的分区

def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)] = ssc.withScope {
   new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
     rememberPartitioner, Some(initialRDD))
}

完整的例子:

def testUpdate={
    val sc = SparkUtils.getSpark("test", "db01").sparkContext
    val ssc = new StreamingContext(sc, Seconds(5))
    ssc.checkpoint("hdfs://ns1/config/checkpoint")
    val initialRDD = sc.parallelize(List(("hello", 1), ("world", 1)))
    val lines = ssc.fileStream[LongWritable,Text,TextInputFormat]("hdfs://ns1/config/data/")
    val words = lines.flatMap(x=>x._2.toString.split(","))
    val wordDstream :DStream[(String, Int)]= words.map(x => (x, 1))
    val result=wordDstream.reduceByKey(_ + _)

    def function1(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
      val newCount = newValues.sum + runningCount.getOrElse(0) // add the new values with the previous running count to get the new count
      Some(newCount)
    }
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
      iterator.flatMap(t => function1(t._2, t._3).map(s => (t._1, s)))
    }
    val stateDS=result.updateStateByKey(newUpdateFunc,new HashPartitioner (sc.defaultParallelism),true,initialRDD)
    stateDS.print()
    ssc.start()
    ssc.awaitTermination()
  }
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值