Computing Statistics over Data with Spark: Max, Min, Mean, Variance, Count (Java Code)

Computing statistics with Spark

Core classes used

org.apache.spark.mllib.stat.Statistics
org.apache.spark.mllib.stat.MultivariateStatisticalSummary

Java code

package ml.summary;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary;
import org.apache.spark.mllib.stat.Statistics;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * @author wendao
 * @since 5/28/20 4:29 PM
 */
public class TestSummaryPure {
    public static void main(String[] args) {
        // Run locally with a single thread; point the master at a cluster as needed.
        SparkConf conf = new SparkConf().setAppName("test").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);

        // Three sample rows: an integer id plus a 3-dimensional dense feature vector.
        List<Row> data = Arrays.asList(
                RowFactory.create(0, Vectors.dense(1.0, 0.1, -1.0)),
                RowFactory.create(1, Vectors.dense(2.0, 1.1, 1.0)),
                RowFactory.create(2, Vectors.dense(3.0, 10.1, 3.0))
        );
        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });
        Dataset<Row> dataFrame = sqlContext.createDataFrame(data, schema);
        // Pull the feature vectors out of the DataFrame so they can be fed to colStats().
        JavaRDD<Vector> vectorJavaRDD = dataFrame.toJavaRDD().mapPartitions(new FlatMapFunction<Iterator<Row>, Vector>() {
            @Override
            public Iterator<Vector> call(Iterator<Row> iterator) throws Exception {

                List<Vector> results = new ArrayList<Vector>();

                while (iterator.hasNext()) {
                    Row row = iterator.next();
                    Vector vector = (Vector) row.getAs("features");

                    // Re-wrap as an mllib DenseVector: Statistics.colStats() requires mllib Vectors.
                    results.add(new org.apache.spark.mllib.linalg.DenseVector(vector.toArray()));
                }

                return results.iterator();
            }
        });

        // Column-wise summary statistics: max, min, count, mean, variance.
        MultivariateStatisticalSummary summary = Statistics.colStats(vectorJavaRDD.rdd());
        Vector max = summary.max();
        System.out.println("max: "+ Arrays.toString(max.toArray()));
        System.out.println("min: "+Arrays.toString(summary.min().toArray()));
        System.out.println("count:" +summary.count());
        System.out.println("mean: "+Arrays.toString(summary.mean().toArray()));
        System.out.println("var: "+Arrays.toString(summary.variance().toArray()));
    }
}
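
Run locally, this should print something close to the following (note that variance() returns the unbiased sample variance, i.e. with an n-1 denominator):

max: [3.0, 10.1, 3.0]
min: [1.0, 0.1, -1.0]
count: 3
mean: [2.0, 3.7666666666666666, 1.0]
var: [1.0, 30.333333333333332, 4.0]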

The code uses classes from the mllib.* package. The ml.* equivalents can be used as well, as long as the data format is converted accordingly.
One caveat: the Vectors inside the RDD passed to Statistics.colStats() must be mllib Vectors, not ml ones.
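
For reference, here is a minimal sketch of that conversion, assuming Spark 2.x, where Vectors.fromML() bridges the two vector types. The mlDataFrame variable and its features column are hypothetical stand-ins for a DataFrame whose schema uses org.apache.spark.ml.linalg.VectorUDT:

        // A minimal sketch (assumption: Spark 2.x, which provides Vectors.fromML()).
        // "mlDataFrame" is a hypothetical DataFrame whose "features" column holds
        // org.apache.spark.ml.linalg.Vector values.
        JavaRDD<org.apache.spark.mllib.linalg.Vector> mllibVectors =
                mlDataFrame.toJavaRDD().map(row -> {
                    org.apache.spark.ml.linalg.Vector v = row.getAs("features");
                    return org.apache.spark.mllib.linalg.Vectors.fromML(v);
                });
        MultivariateStatisticalSummary mlSummary = Statistics.colStats(mllibVectors.rdd());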
