Spark SQL custom UDF, UDAF, UDTF, and shuffle partition count

This post reads JSON test data with Spark SQL, observes how a DataFrame's partition count changes before and after a shuffle, and implements custom functions for aggregation: a scalar UDF, an untyped UDAF, and a strongly typed UDAF, showing the partition counts and results for each scenario.

Test data
{"name":"zhangsan", "age":20}
{"name":"lisi", "age":20}
{"name":"wangwu", "age":20}
{"name":"wangwu", "age":30}
{"name":"wangwu", "age":35}
Spark SQL partition count test code
package sparkSql

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.{SparkConf, SparkContext, sql}
import org.junit.{After, Before, Test}

class ReadJson {
	val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")
	val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
	var outpath: String = "out"

	// util.MyPredef provides an implicit conversion that adds delete() to String
	// paths, used in init() to clear the output directory before each test
	import util.MyPredef._

	@Before
	def init(): Unit = {
		outpath.delete()
	}

	@After
	def after(): Unit = {
		spark.stop()
	}

	/**
	 * 1. Build the SparkSession
	 * 2. Read the data
	 * 3. Write out or show the result
	 */
	@Test
	def queryJson(): Unit = {

		val df: DataFrame = spark.read.json("in/user.json")
		df.show()
	}

	// Inspect partition counts before and after a shuffle via the underlying RDD
	@Test
	def dfToRDD(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		val noShuffledRDD: DataFrame = spark.sql("select * from user")
		println("noShuffledRDD 分区数:" + noShuffledRDD.rdd.getNumPartitions)
		noShuffledRDD.show()


		val ShuffledRDD: DataFrame = spark.sql("select name, count(1) from user group by name")
		println("ShuffledRDD 分区数:" + ShuffledRDD.rdd.getNumPartitions)
		ShuffledRDD.show()
	}
}

Spark SQL partition count output

noShuffledRDD partition count: 1
+---+--------+
|age|    name|
+---+--------+
| 20|zhangsan|
| 20|    lisi|
| 20|  wangwu|
| 30|  wangwu|
| 35|  wangwu|
+---+--------+

ShuffledRDD partition count: 200
+--------+--------+
|    name|count(1)|
+--------+--------+
|  wangwu|       3|
|zhangsan|       1|
|    lisi|       1|
+--------+--------+
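
The single input partition comes from reading one small JSON file, while the 200 partitions after the group by are Spark SQL's default shuffle parallelism, controlled by spark.sql.shuffle.partitions (200 by default). For a small local dataset this is far more than needed. A minimal sketch of lowering it, either when building the session or at runtime (the value 4 is an arbitrary example, not from the original code):

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")
val spark: SparkSession = SparkSession.builder()
	.config(conf)
	// default is 200; use fewer post-shuffle partitions for small local jobs
	.config("spark.sql.shuffle.partitions", "4")
	.getOrCreate()

// can also be changed on an existing session
spark.conf.set("spark.sql.shuffle.partitions", "4")

With this setting, the grouped query above would report 4 post-shuffle partitions instead of 200.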
Custom UDF and UDAF test code
package sparkSql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
import org.junit.{After, Before, Test}

class sparkSqlFunction {
	val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")

	var outpath: String = "out"

	val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

	import util.MyPredef._

	@Before
	def init(): Unit = {
		outpath.delete()
	}

	@After
	def after(): Unit = {
		spark.stop()
	}

	/**
	 * 0. Custom UDF: a scalar function applied to each row
	 */
	@Test
	def udfTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		spark.udf.register("add_name", (name: String) => name + " UDF")

		val resDF: DataFrame = spark.sql("select add_name(name), * from user")

		resDF.show()

	}

	/**
	 * 1. Custom untyped UDAF
	 */
	@Test
	def udafTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		val udaf: AvgUDAFFunction = new AvgUDAFFunction

		spark.udf.register("my_avg", udaf)

		val resDF: DataFrame = spark.sql("select name, my_avg(age) from user group by name")

		resDF.show()

	}

	/**
	 * 2. Custom strongly typed UDAF
	 */
	@Test
	def udafClassTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		val udaf = new AvgClassUDAFFunction

		val age: TypedColumn[UserBean, Double] = udaf.toColumn.name("avg_age")

		import spark.implicits._
		val ds: Dataset[UserBean] = df.as[UserBean]

		ds.select(age).show()
	}
}


/**
 * Untyped (weakly typed) UDAF: operates on Row buffers with a declared schema
 */
class AvgUDAFFunction extends UserDefinedAggregateFunction {

	// Schema of the input arguments
	override def inputSchema: StructType = {
		new StructType().add("age", LongType)
	}

	// Schema of the aggregation buffer
	override def bufferSchema: StructType = {
		new StructType().add("sum", LongType).add("count", LongType)
	}

	// Return type of the UDAF
	override def dataType: DataType = DoubleType

	// Whether the function is deterministic: the same input always produces the same output
	override def deterministic: Boolean = true

	// Initialize the aggregation buffer
	override def initialize(buffer: MutableAggregationBuffer): Unit = {
		buffer(0) = 0L
		buffer(1) = 0L
	}

	// Update the buffer with a single input row
	override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
		// sum
		buffer(0) = buffer.getLong(0) + input.getLong(0)
		// count
		buffer(1) = buffer.getLong(1) + 1
	}

	// Merge buffers coming from different partitions
	override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
		// sum
		buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
		// count
		buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
	}

	// Compute the final result; convert to Double before dividing so the average is not truncated
	override def evaluate(buffer: Row): Any = {
		buffer.getLong(0).toDouble / buffer.getLong(1)
	}
}

/**
 * Strongly typed UDAF: an Aggregator over the UserBean case class
 */
case class UserBean(name: String, age: Long)
// Long fields so summing Long ages cannot overflow or lose precision
case class AvgBuffer(var sum: Long, var count: Long)

class AvgClassUDAFFunction extends Aggregator[UserBean, AvgBuffer, Double] {
	// Initial (zero) value of the buffer
	override def zero: AvgBuffer = AvgBuffer(0, 0)

	// Aggregate within a partition
	override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
		b.count = b.count + 1
		b.sum = b.sum + a.age
		b
	}

	// Merge buffers across partitions
	override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
		b1.sum = b1.sum + b2.sum
		b1.count = b1.count + b2.count
		b1
	}

	// Produce the final result from the merged buffer
	override def finish(reduction: AvgBuffer): Double = {
		reduction.sum.toDouble / reduction.count
	}

	// Encoder for the intermediate buffer type
	override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

	// Encoder for the output type
	override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

Custom UDF and UDAF test results

udfTest (UDF registered from an anonymous function):


+------------------+---+--------+
|UDF:add_name(name)|age|    name|
+------------------+---+--------+
|      zhangsan UDF| 20|zhangsan|
|          lisi UDF| 20|    lisi|
|        wangwu UDF| 20|  wangwu|
|        wangwu UDF| 30|  wangwu|
|        wangwu UDF| 35|  wangwu|
+------------------+---+--------+
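
The registered SQL UDF also has a DataFrame-API counterpart built with org.apache.spark.sql.functions.udf. A minimal sketch (addName and name_udf are illustrative names, not from the original code):

import org.apache.spark.sql.functions.{col, udf}

val addName = udf((name: String) => name + " UDF")
df.withColumn("name_udf", addName(col("name"))).show()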

udafTest:

+--------+--------------------+
|    name|avgudaffunction(age)|
+--------+--------------------+
|  wangwu|  28.333333333333332|
|zhangsan|                20.0|
|    lisi|                20.0|
+--------+--------------------+

udafClassTest:

+-------+
|avg_age|
+-------+
|   25.0|
+-------+
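
Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0; the recommended replacement is to write an Aggregator and register it with functions.udaf, which also makes it callable from SQL. A minimal sketch, assuming Spark 3.0+ and reusing the AvgBuffer case class above (LongAvg and my_avg2 are illustrative names, not from the original code); the input type is the raw Long column value rather than a whole row bean:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

object LongAvg extends Aggregator[Long, AvgBuffer, Double] {
	override def zero: AvgBuffer = AvgBuffer(0L, 0L)
	// accumulate one column value within a partition
	override def reduce(b: AvgBuffer, a: Long): AvgBuffer = { b.sum += a; b.count += 1; b }
	// combine partial buffers across partitions
	override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
	override def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count
	override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
	override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

spark.udf.register("my_avg2", functions.udaf(LongAvg, Encoders.scalaLong))
spark.sql("select name, my_avg2(age) from user group by name").show()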