Spark的RDD的aggregate() 函数

aggregate() 函数的返回类型不需要和 RDD 中的元素类型一致,所以在使用时,需要提供所期待的返回类型的初始值,然后通过一个函数把 RDD 中的元素累加起来放入累加器。

考虑到每个结点都是在本地进行累加的,所以最终还需要提供第二个函数来将累加器两两合并。

aggregate(zero)(seqOp,combOp) 函数首先使用 seqOp 操作聚合各分区中的元素,然后再使用 combOp 操作把所有分区的聚合结果再次聚合,两个操作的初始值都是 zero。

seqOp 的操作是遍历分区中的所有元素 T,第一个 T 跟 zero 做操作,结果再作为与第二个 T 做操作的 zero,直到遍历完整个分区。

combOp 操作是把各分区聚合的结果再聚合。aggregate() 函数会返回一个跟 RDD 不同类型的值。因此,需要 seqOp 操作来把分区中的元素 T 合并成一个 U,以及 combOp 操作把所有 U 聚合。

下面举一个利用 aggreated() 函数求平均数的例子。

val rdd = List(1,2,3,4)
    val input = sc.parallelize(rdd)
//    (x,y) => (x[0]+y,x[1]+1)
//      def:x => (x+1);
//    (acc,value) => (acc._1 + value,acc._2 + 1)
    val result = input.aggregate((0,0))(//初始值
       //累加器 (元组累加元组结果,RDD单个元素值)=>(元组累加结果+RDD单个元素,元组累加计数+1)
      //acc元组:(acc._1,acc._2)  (10,1)
  //acc._1:累加结果
  //acc._2:计数
      (acc,value) => (acc._1 + value,acc._2 + 1),
        //combine 合并函数 合并元组累加结果
      (acc1,acc2) => (acc1._1 + acc2._1,acc1._2 + acc2._2)
    )
    val avg = result._1/result._2

源码:

 def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
    // Clone the zero value since we will also be serializing it as part of tasks
    var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
    val cleanSeqOp = sc.clean(seqOp)
    val cleanCombOp = sc.clean(combOp)
    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
    sc.runJob(this, aggregatePartition, mergeResult)
    jobResult
  }

程序的详细过程大概如下。

定义一个初始值 (0,0),即所期待的返回类型的初始值。代码 (acc,value) => (acc._1 + value,acc._2 + 1) 中的 value 是函数定义里面的 T,这里是 List 里面的元素。acc._1 + value,acc._2 + 1 的过程如下。

(0+1,0+1)→(1+2,1+1)→(3+3,2+1)→(6+4,3+1),结果为(10,4)。

实际的 Spark 执行过程是分布式计算,可能会把 List 分成多个分区,假如是两个:p1(1,2) 和 p2(3,4)。

经过计算,各分区的结果分别为 (3,2) 和 (7,2)。这样,执行 (acc1,acc2) => (acc1._1 + acc2._1,acc1._2 + acc2._2) 的结果就是 (3+7,2+2),即 (10,4),然后可计算平均值。

 

首先,Spark文档中aggregate函数定义如下

def aggregate[U](zeroValue: U)(seqOp: (U, T) ⇒ U, combOp: (U, U) ⇒ U)(implicit arg0: ClassTag[U]): U
Aggregate the elements of each partition, and then the results for all the partitions, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are allowed to modify and return their first argument instead of creating a new U to avoid memory allocation.   seqOp操作会聚合各分区中的元素,然后combOp操作把所有分区的聚合结果再次聚合,两个操作的初始值都是zeroValue.   seqOp的操作是遍历分区中的所有元素(T),第一个T跟zeroValue做操作,结果再作为与第二个T做操作的zeroValue,直到遍历完整个分区。combOp操作是把各分区聚合的结果,再聚合。aggregate函数返回一个跟RDD不同类型的值。因此,需要一个操作seqOp来把分区中的元素T合并成一个U,另外一个操作combOp把所有U聚合。

zeroValue
the initial value for the accumulated result of each partition for the seqOp operator, and also the initial value for the combine results from different partitions for the combOp operator - this will typically be the neutral element (e.g. Nil for list concatenation or 0 for summation)

seqOp
an operator used to accumulate results within a partition

combOp
an associative operator used to combine results from different partitions

举个例子。假如List(1,2,3,4,5,6,7,8,9,10),对List求平均数,使用aggregate可以这样操作。
C:\Windows\System32>scala
Welcome to Scala 2.11.8 (Java HotSpot(TM) Client VM, Java 1.8.0_91).
Type in expressions for evaluation. Or try :help.

scala> val rdd = List(1,2,3,4,5,6,7,8,9)
rdd: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9)

scala> rdd.par.aggregate((0,0))(

(acc,number) => (acc._1 + number, acc._2 + 1),

(par1,par2) => (par1._1 + par2._1, par1._2 + par2._2)

)
res0: (Int, Int) = (45,9)

scala> res0._1 / res0._2
res1: Int = 5

过程大概这样:

首先,初始值是(0,0),这个值在后面2步会用到。

然后,(acc,number) => (acc._1 + number, acc._2 + 1),number即是函数定义中的T,这里即是List中的元素。所以acc._1 + number, acc._2 + 1的过程如下。

1.   0+1,  0+1

2.  1+2,  1+1

3.  3+3,  2+1

4.  6+4,  3+1

5.  10+5,  4+1

6.  15+6,  5+1

7.  21+7,  6+1

8.  28+8,  7+1

9.  36+9,  8+1

结果即是(45,9)。这里演示的是单线程计算过程,实际Spark执行中是分布式计算,可能会把List分成多个分区,假如3个,p1(1,2,3,4),p2(5,6,7,8),p3(9),经过计算各分区的的结果(10,4),(26,4),(9,1),这样,执行(par1,par2) => (par1._1 + par2._1, par1._2 + par2._2)就是(10+26+9,4+4+1)即(45,9).再求平均值就简单了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值