关于spark的mllib学习总结(Java版)

本篇博客主要讲述如何利用 Spark 的mliib构建 机器学习 模型并预测新的数据,具体的流程如下图所示: 

基本流程

加载数据

对于数据的加载或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的数据是采用spark中提供的数据sample_libsvm_data.txt,其有一百个数据样本,658个特征。具体的数据形式如图所示: 
数据格式

加载libsvm

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
    
  • 1
  • 1

LabeledPoint数据类型是对应与libsvmfile格式文件, 具体格式为: 
Lable(double类型),vector(Vector类型)

转化dataFrame数据类型

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
                    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("features", new VectorUDT(), false, Metadata.empty()),
        });
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

DataFrame:DataFrame是一个以命名列方式组织的分布式数据集。在概念上,它跟关系型数据库中的一张表或者1个Python(或者R)中的data frame一样,但是比他们更优化。DataFrame可以根据结构化的数据文件、Hive表、外部数据库或者已经存在的RDD构造。

SQLContext:spark sql所有功能的入口是SQLContext类,或者SQLContext的子类。为了创建一个基本的SQLContext,需要一个SparkContext。

特征提取

特征归一化处理

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
DataFrame scalerDF = scaler.fit(df).transform(df);
scaler.save(this.scalerModelPath);
    
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3

利用卡方统计做特征提取

ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
chiModel.save(this.featureSelectedModelPath);
    
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

训练机器学习模型(以SVM为例)

//转化为LabeledPoint数据类型, 训练模型
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());

//训练SVM模型, 并保存
int numIteration = 200;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
model.clearThreshold();
model.save(sc, this.mlModelPath);

// LabeledPoint数据类型转化为Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> {

        public Row call(LabeledPoint p) throws Exception {
            double label = p.label();
            Vector vector = p.features();
            return RowFactory.create(label, vector);
        }
    }

//Rows数据类型转化为LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> {

        public LabeledPoint call(Row r) throws Exception {
            Vector features = r.getAs(1);
            double label = r.getDouble(0);
            return new LabeledPoint(label, features);
        }
    }
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

测试新的样本

测试新的样本前,需要将样本做数据的转化和特征提取的工作,所有刚刚训练模型的过程中,除了保存机器学习模型,还需要保存特征提取的中间模型。具体代码如下:

//初始化spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "2147480000");
SparkContext sc = new SparkContext(conf);

//加载测试数据
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();

//转化DataFrame数据类型
JavaRDD<Row> jrow =testData.map(new LabeledPointToRow());
        StructType schema = new StructType(new StructField[]{
                    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("features", new VectorUDT(), false, Metadata.empty()),
        });
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);

        //数据规范化
StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
DataFrame scalerDF = scaler.fit(df).transform(df);

        //特征选取
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

测试数据集

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ;
predictResult.collect();

static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
        SVMModel model;
        public Prediction(SVMModel model){
            this.model = model;
        }
        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
            Double score = model.predict(p.features());
            return new Tuple2<Double , Double>(score, p.label());
        }
    }
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

计算准确率

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy);

static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
        public Boolean call(Tuple2<Double, Double> t) throws Exception {
            double score = t._1();
            double label = t._2();
            System.out.print("score:" + score + ", label:"+ label);
            if(score >= 0.0 && label >= 0.0) return true;
            else if(score < 0.0 && label < 0.0) return true;
            else return false;
        }
    }
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

具体的代码,放在我的github上:https://github.com/Quincy1994/MachineLearning/

3
0
 
 
我的同类文章

参考知识库

img
机器学习知识库

img
Python知识库

img
Hive知识库

img
人工智能机器学习知识库

img
软件测试知识库

img
MySQL知识库

img
Apache Spark知识库

猜你在找
8小时学会HTML网页开发
Android入门实战教程
Swift视频教程(第三季)
Swift视频教程(第七季)
Swift视频教程(第六季)
多层感知机MLP算法原理及Spark MLlib调用实例ScalaJavaPython
梯度迭代树GBDT算法原理及Spark MLlib调用实例ScalaJavapython
Pipeline详解及Spark MLlib使用示例ScalaJavaPython
spark MLlib 学习
spark mllib机器学习之五 LinearRegressionWithSGD
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值