import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.expressions.UserDefinedFunction // 导入 UserDefinedFunction
import org.apache.spark.sql.expressions.Window
// Build (or reuse) the SparkSession used by this Word2Vec job. Settings are applied in order.
// NOTE(review): per the Spark API docs, getOrCreate() hands back an already-running session
// when one exists (e.g. inside spark-shell), in which case some of these settings may not take
// effect. spark.driver.memory in particular generally must be supplied at launch time
// (e.g. `spark-submit --driver-memory 32g`) because the driver JVM is already running here —
// confirm against how this script is launched.
val spark = Seq(
  "spark.driver.memory"             -> "32g",
  "spark.executor.memory"           -> "28g",
  "spark.driver.maxResultSize"      -> "16g",
  "spark.sql.shuffle.partitions"    -> "2000",
  "spark.serializer"                -> "org.apache.spark.serializer.KryoSerializer",
  "spark.dynamicAllocation.enabled" -> "true"
).foldLeft(SparkSession.builder().appName("w2v0218")) { case (builder, (key, value)) =>
  builder.config(key, value)
}.getOrCreate()
// Load each device together with its package list (pkg_list, exposed as package_list, which is
// the input column Word2Vec is configured to read).
// NOTE(review): assumes pkg_list is array<string>; Word2Vec rejects any other column type.
val query = """select device,pkg_list package_list from database.table"""
println(query)
// Spark's Word2Vec dereferences every row's token sequence during fit and transform, so a
// single null package list would fail the whole job. Drop such rows before training.
val queryDf = spark.sql(query).where(col("package_list").isNotNull)
queryDf.show(5)
// Configure a Word2Vec estimator that reads token lists from "package_list" and writes one
// embedding vector per row to "features", then train it on queryDf.
// NOTE(review): minCount here is 40000, but the output table written later is named
// "...mc4000" — one of the two looks like a typo; confirm which is intended.
val word2Vec = {
  val estimator = new Word2Vec()
  estimator.setInputCol("package_list")
  estimator.setOutputCol("features")
  estimator.setVectorSize(256)     // embedding dimension
  estimator.setWindowSize(10)      // context window size
  estimator.setMinCount(40000)     // min occurrences for a token to enter the vocabulary
  estimator.setMaxIter(1)          // single training pass
  estimator.setNumPartitions(2000)
  estimator
}
val model = word2Vec.fit(queryDf)
println("fit ok")
// Embed every row with the trained model, keep only (device, features), and save the result as
// a Snappy-compressed table, overwriting any existing table of that name. Then shut Spark down.
// NOTE(review): the "mc4000" suffix in the table name does not match setMinCount(40000) used
// for training — confirm the intended value.
val deviceEmbeddings = model.transform(queryDf).select("device", "features")
deviceEmbeddings.write
  .mode("overwrite")
  .option("compression", "snappy")
  .saveAsTable("database.bx_v256_mc4000")
spark.stop()