package xyz.ixiaoban.bigdata.spark
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.{SparkConf, SparkContext}
/**
 * Small Spark MLlib example: fits a linear regression model with stochastic
 * gradient descent on a comma/space separated text file and prints the
 * prediction for one hard-coded feature vector.
 *
 * Each non-blank input line is read as `label,feature1 feature2 ...`:
 * the text before the first comma is the label, the text after it is a
 * whitespace-separated list of feature values.
 */
object LinearRegression {
  // Spark configuration: local master, fixed application name.
  val conf = new SparkConf().setMaster("local").setAppName("LinearRegression3 ")
  // Spark context built from the configuration above; stopped at the end of main.
  val sc = new SparkContext(conf)

  /**
   * Loads the data set, trains the model and prints the predicted value for
   * the feature vector (4, 5).
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    // Training hyper-parameters passed to LinearRegressionWithSGD.train.
    val numIterations = 100
    val stepSize = 0.1

    try {
      val data = sc.textFile("file:\\c:\\A.txt") // location of the data set

      // Convert every usable line into a LabeledPoint. The parsing is kept
      // inline so the closure does not reference members of this object.
      val parsedData = data
        .filter(_.trim.nonEmpty) // skip blank lines instead of failing on parts(1)
        .map { line =>
          val parts = line.split(',') // label on the left, features on the right
          val label = parts(0).trim.toDouble
          // Trim and split on whitespace runs so stray or repeated spaces do not
          // yield empty tokens that would fail in toDouble.
          val features = parts(1).trim.split("\\s+").map(_.toDouble)
          LabeledPoint(label, Vectors.dense(features))
        }
        .cache() // the RDD is cached before being handed to the trainer

      val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) // fit the model
      val result = model.predict(Vectors.dense(4.0, 5.0)) // predict for a sample point
      println(result) // print the predicted value
    } finally {
      // Always release the Spark context, even when loading or training fails.
      sc.stop()
    }
  }
}