本人是做java的,但是因为scala对spark的友好,所以因为好奇并且为了学习,再写代码之前还是决定使用scala来做
先贴核心的样本率处理
/**
* 正样本率多参递归处理
* @param spark sparksession
* @param sqlAllDF 主表(每次处理的结果数据)
* @param tableName 目标表名
* @param col 参数长度-1(标识 控制跳出递归)
* @param plus 正负样本标识
* @param column 类别字段(多个)
* @return
*/
def fac(spark:SparkSession,sqlAllDF:DataFrame
,tableName:String,col:Int,plus:String,column:Seq[String]): DataFrame ={
import spark.sql
if (column.length != 0 && col >= 0){
val c = column(col)
val sqlDF = sql(s"SELECT $c cs, sum($plus)/count(1) result from $tableName GROUP BY $c")
val reDF = sqlAllDF.join(sqlDF, sqlAllDF(c) === sqlDF("cs"),"left")
val df = reDF.withColumn(column(col),reDF("result"))
.drop("cs")
.drop("result")
val cc = col - 1
fac(spark,df,tableName,cc,plus,column)
}else{
sqlAllDF
}
}
然后是spark的调用
/**
* 正样本率函数
* @param spark
* @param tableName 目标表名
* @param plus 正负样本标识
* @param tableName_result 输出表名
* @param column 类别字段(多个)
*/
def rateing(spark:SparkSession,tableName:String, plus:String, tableName_result:String, column:Seq[String]): Unit ={
val records = "records"
import spark.sql
val sqlAllDF = sql(s"SELECT * from $tableName ")
val df = fac(spark,sqlAllDF,tableName,column.length-1,plus,column)
df.createOrReplaceTempView(records)
sql(s"CREATE TABLE $tableName_result like $records")
sql(s"insert overwrite table $tableName_result select * from $records")
sql(s"SELECT * FROM $tableName_result ").show()
spark.stop()
}
其实到这里 已经结束了,为了让代码全一点,贴下入口
/**
* 特征入口
*/
def main(args: Array[String]): Unit = {
val warehouseLocation = "file:${system:user.dir}/spark-warehouse"
val spark = SparkSession
.builder
.appName("SampleRate")
.config("spark.sql.warehouse.dir", warehouseLocation)
.master("local")
.enableHiveSupport()
.getOrCreate()
rateing(spark,"test_multiple_col","status","test_multiple_col_result",Seq("name","os","app"))
}
举个例子,有表如图
执行之后(结果替换原来的字段,这是业务需求)
刚刚接触scala,欢迎大家指教
<spark.version>2.1.0</spark.version> <scala.version>2.11.8</scala.version>