测试数据
{"name":"zhangsan", "age":20}
{"name":"lisi", "age":20}
{"name":"wangwu", "age":20}
{"name":"wangwu", "age":30}
{"name":"wangwu", "age":35}
spark sql 分区数测试代码
package sparkSql
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.{SparkConf, SparkContext, sql}
import org.junit.{After, Before, Test}
class ReadJson {

  // Local-mode Spark with 3 worker threads for these JUnit tests.
  val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")
  val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

  // Output directory, wiped before every test so writes start clean.
  var outpath: String = "out"

  import util.MyPredef._

  @Before
  def init(): Unit = {
    // `delete()` is an implicit enrichment on String from util.MyPredef.
    outpath.delete()
  }

  @After
  def after(): Unit = {
    spark.stop()
  }

  /**
   * Basic read path:
   * 1. build the SparkSession (done in the field initializers above)
   * 2. read the JSON file into a DataFrame
   * 3. display the result
   */
  @Test
  def queryJson(): Unit = {
    val df: DataFrame = spark.read.json("in/user.json")
    df.show()
  }

  /**
   * Compare partition counts: a narrow `select *` keeps the input's
   * partitioning, while `group by` triggers a shuffle and defaults to
   * `spark.sql.shuffle.partitions` (200) post-shuffle partitions.
   */
  @Test
  def dfToRDD(): Unit = {
    val df: DataFrame = spark.read.json("in/user.json")
    df.createOrReplaceTempView("user")

    val noShuffledDF: DataFrame = spark.sql("select * from user")
    println("noShuffledRDD 分区数:" + noShuffledDF.rdd.getNumPartitions)
    noShuffledDF.show()

    val shuffledDF: DataFrame = spark.sql("select name, count(1) from user group by name")
    println("ShuffledRDD 分区数:" + shuffledDF.rdd.getNumPartitions)
    shuffledDF.show()
  }
}
spark sql 分区数输出结果
noShuffledRDD 分区数:1
+---+--------+
|age| name|
+---+--------+
| 20|zhangsan|
| 20| lisi|
| 20| wangwu|
| 30| wangwu|
| 35| wangwu|
+---+--------+
ShuffledRDD 分区数:200
+--------+--------+
| name|count(1)|
+--------+--------+
| wangwu| 3|
|zhangsan| 1|
| lisi| 1|
+--------+--------+
自定义 UDAF 测试代码
package sparkSql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
import org.junit.{After, Before, Test}
class sparkSqlFunction {

  // Local-mode Spark with 3 worker threads for these JUnit tests.
  val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")

  // Output directory, wiped before every test so writes start clean.
  var outpath: String = "out"

  val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

  import util.MyPredef._

  @Before
  def init(): Unit = {
    // `delete()` is an implicit enrichment on String from util.MyPredef.
    outpath.delete()
  }

  @After
  def after(): Unit = {
    spark.stop()
  }

  /**
   * 0. Custom UDF: register an anonymous function under the name
   *    "add_name" and use it from SQL.
   */
  @Test
  def udfTest(): Unit = {
    val df: DataFrame = spark.read.json("in/user.json")
    df.createOrReplaceTempView("user")
    spark.udf.register("add_name", (name: String) => name + " UDF")
    val resDF: DataFrame = spark.sql("select add_name(name), * from user")
    resDF.show()
  }

  /**
   * 1. Untyped (weak-typed) UDAF: register an AvgUDAFFunction instance
   *    under "my_avg" and aggregate with it in SQL.
   */
  @Test
  def udafTest(): Unit = {
    val df: DataFrame = spark.read.json("in/user.json")
    df.createOrReplaceTempView("user")
    val udaf: AvgUDAFFunction = new AvgUDAFFunction
    spark.udf.register("my_avg", udaf)
    val resDF: DataFrame = spark.sql("select name, my_avg(age) from user group by name")
    resDF.show()
  }

  /**
   * 2. Strongly-typed UDAF: convert the Aggregator to a TypedColumn and
   *    select it on a Dataset[UserBean].
   */
  @Test
  def udafClassTest(): Unit = {
    val df: DataFrame = spark.read.json("in/user.json")
    df.createOrReplaceTempView("user")
    val udaf = new AvgClassUDAFFunction
    val age: TypedColumn[UserBean, Double] = udaf.toColumn.name("avg_age")
    import spark.implicits._
    val ds: Dataset[UserBean] = df.as[UserBean]
    ds.select(age).show()
  }
}
/**
* 弱类型 UDAF
*/
/**
 * Untyped (weak-typed) UDAF computing the average of a Long column.
 * Buffer layout: (sum: Long, count: Long); result: Double.
 */
class AvgUDAFFunction extends UserDefinedAggregateFunction {

  // Schema of the input column (one Long argument: age).
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer: running sum and row count.
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Result type returned by evaluate().
  override def dataType: DataType = DoubleType

  // Same input always produces the same output.
  override def deterministic: Boolean = true

  // Zero out the buffer before aggregation starts.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // sum
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    // count
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge partial buffers produced on different partitions/nodes.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0))
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result. BUG FIX: the original did `(sum / count) toDouble`,
  // which performs integer division first and truncates (e.g. 85/3 -> 28.0
  // instead of 28.333...). Convert to Double BEFORE dividing, and guard
  // against an empty group to avoid division by zero.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) Double.NaN
    else buffer.getLong(0).toDouble / count
  }
}
/**
* 强类型 UDAF
*/
// Input row type: one user record.
case class UserBean(name: String, age: Long)

// Mutable aggregation buffer. BUG FIX: fields widened from Int to Long —
// age is a Long, and the original `(b.sum + a.age).toInt` silently
// truncated the sum on overflow. Existing `AvgBuffer(0, 0)` call sites
// still compile via numeric widening.
case class AvgBuffer(var sum: Long, var count: Long)

/**
 * Strongly-typed UDAF computing the average age of UserBean records.
 * IN = UserBean, BUF = AvgBuffer, OUT = Double.
 */
class AvgClassUDAFFunction extends Aggregator[UserBean, AvgBuffer, Double] {

  // Initial (empty) buffer.
  override def zero: AvgBuffer = AvgBuffer(0, 0)

  // Fold one record into the buffer (within a partition).
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.count = b.count + 1
    b.sum = b.sum + a.age
    b
  }

  // Combine partial buffers from different partitions.
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Produce the final average; NaN for an empty input instead of
  // throwing on division by zero.
  override def finish(reduction: AvgBuffer): Double = {
    if (reduction.count == 0L) Double.NaN
    else reduction.sum.toDouble / reduction.count
  }

  // Encoder for the intermediate buffer (a Product/case class).
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  // Encoder for the Double result.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
自定义 UDAF 测试结果
udfTest 输出(匿名函数注册为 UDF):
+------------------+---+--------+
|UDF:add_name(name)|age| name|
+------------------+---+--------+
| zhangsan UDF| 20|zhangsan|
| lisi UDF| 20| lisi|
| wangwu UDF| 20| wangwu|
| wangwu UDF| 30| wangwu|
| wangwu UDF| 35| wangwu|
+------------------+---+--------+
udafTest 输出(弱类型 UDAF):
+--------+--------------------+
| name|avgudaffunction(age)|
+--------+--------------------+
| wangwu| 28.0|
|zhangsan| 20.0|
| lisi| 20.0|
+--------+--------------------+
udafClassTest 输出(强类型 UDAF):
+-------+
|avg_age|
+-------+
| 25.0|
+-------+