线性回归

package com.cfl.sparkmllib.lr;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

/**
 * 线性回归
 * @author chenfenli
 *
 */
public class LinearRegression {

	@SuppressWarnings("resource")
	public static void main(String[] args) {
		SparkConf sparkConf = new SparkConf();
		sparkConf.setAppName("LinearRegression");
		sparkConf.setMaster("local");
		JavaSparkContext sparkContext = new JavaSparkContext(sparkConf);
		
		// 一、读取样本数据
		JavaRDD<String> data = sparkContext.textFile("resources/lpsa.data");
		JavaRDD<LabeledPoint> examples = data.map(new Function<String, LabeledPoint>() {
			private static final long serialVersionUID = 1L;
			public LabeledPoint call(String line) throws Exception {
				// -0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
				String[] yxs = line.split(",");
				String[] xStrs = yxs[1].split(" ");
				Double y = Double.valueOf(yxs[0]);		// 根据特征向量,求出来的y值
				double[] xs=new double[xStrs.length];	// 特征向量
				for(int i = 0; i < xs.length; i++) {
					xs[i] = Double.parseDouble(xStrs[i]);
				}
				Vector vector = Vectors.dense(xs);
				return new LabeledPoint(y,  vector);
			}
		});
		
		// 二、设置算法
		// 1、创建线性回归算法
		LinearRegressionWithSGD lrs = new LinearRegressionWithSGD();
		// 2、让训练出来的模型有w0参数,就是有截距
		lrs.setIntercept(true);
		// 3、设置步长 在每次迭代的过程中,梯度下降算法的下降步长大小
		lrs.optimizer().setStepSize(0.5);
		// 4、设置迭代次数:训练一个多元线性回归模型收敛(停止迭代)条件:error值小于用户指定的error值、达到一定的迭代次数
		lrs.optimizer().setNumIterations(60);
		// 5、每一次下山后,是否计算所有测试样本的误差值,1:所有样本,默认就是1.0
		lrs.optimizer().setMiniBatchFraction(1);
		
		
		// 三、训练算法->得到模型
		// 设置样本训练集和测试集的比例 第一个参:训练集比例 第二参数:测试集比例
		double[] weights = {0.8,0.2};
		JavaRDD<LabeledPoint>[] train2TestData = examples.randomSplit(weights,1);
		LinearRegressionModel model = lrs.run(JavaRDD.toRDD(train2TestData[0]));
		System.out.println(model.weights());	// 模型的特征权重
		System.out.println(model.intercept());	// 模型的截距
		
		
		// 四、对样本进行预测
		// 1、获取数据样板的测试集
		JavaRDD<LabeledPoint> testRdds = train2TestData[1];
		// 2、获取到测试集中的特征向量
		JavaRDD<Vector> testResultRdd = testRdds.map(new Function<LabeledPoint, Vector>() {
			private static final long serialVersionUID = 1L;
			public Vector call(LabeledPoint lp) throws Exception {
				return lp.features();
			}
		});
		// 3、通过模型和特征向量预测得到y值 
		JavaRDD<Double> prediction = model.predict(testResultRdd);
		
		// 五、求平均误差
		calculation(testRdds, prediction);
		
		sparkContext.stop();
	}
	
	public static void calculation(JavaRDD<LabeledPoint> testRdd, JavaRDD<Double> prediction) {
		// 5、求误差
		// 1、获取测试集中的真实y值
		JavaRDD<Double> testResultRdd2 = testRdd.map(new Function<LabeledPoint, Double>() {
			private static final long serialVersionUID = 1L;
			public Double call(LabeledPoint lp) throws Exception {
				return lp.label();
			}
		});
		// 2、将真实y值和测试得到的y值压缩到一起,取前20条打印查看
		JavaPairRDD<Double, Double> predictionAndLabel = prediction.zip(testResultRdd2);
		List<Tuple2<Double, Double>> printPredicts = predictionAndLabel.take(20);
		System.out.println("预测值" + "\t" + "真实值");
		for(Tuple2<Double, Double> t : printPredicts) {
			System.out.println(t._1 + "\t" + t._2);
		}
		// 3、计算每个的误差
		JavaRDD<Double> item = predictionAndLabel.map(new Function<Tuple2<Double,Double>, Double>() {
			private static final long serialVersionUID = 1L;
			public Double call(Tuple2<Double, Double> arg0) throws Exception {
				return 	Math.abs(arg0._1 - arg0._2);
			}
		});
		// 4、每个误差求和
		Double loss = item.reduce(new Function2<Double, Double, Double>() {
			private static final long serialVersionUID = 1L;
			public Double call(Double arg0, Double arg1) throws Exception {
				return arg0+arg1;
			}
		});
		// 5、求平均误差
		System.out.println(loss/testRdd.count());
	}
}

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值