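Quick Start with Spark-MLlib, Part 7 (Decision Tree: Classification)

This article walks through implementing a decision-tree classification model: loading the data, parsing it, training the model, and making predictions. The classifier is built with Spark in Java and is applied to a bike rental demand dataset.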


(1) Data

1,2011-01-01,1,0,1,0,0,6,0,1,0.24,0.2879,0.81,0,3,13,16

2,2011-01-01,1,0,1,1,0,6,0,1,0.22,0.2727,0.8,0,8,32,40

3,2011-01-01,1,0,1,2,0,6,0,1,0.22,0.2727,0.8,0,5,27,32

Column meanings

instant,dteday,season,yr,mnth,hr,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt
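To make the column layout concrete, here is a minimal plain-Java sketch (no Spark required) that splits the first sample row above into a feature array and a label, using the same index arithmetic as the parser in section (2). The names `ParseRowSketch`, `features`, and `label` are made up for illustration and are not part of Spark.

```java
import java.util.Arrays;

public class ParseRowSketch {
    // Fields at 0-based indices 2..15 become the feature vector;
    // instant (index 0) and dteday (index 1) are skipped.
    static double[] features(String line) {
        String[] parts = line.split(",");
        double[] v = new double[parts.length - 3];
        for (int i = 0; i < v.length; i++) {
            v[i] = Double.parseDouble(parts[i + 2]);
        }
        return v;
    }

    // The last field (index 16, cnt) is used as the label.
    static double label(String line) {
        String[] parts = line.split(",");
        return Double.parseDouble(parts[parts.length - 1]);
    }

    public static void main(String[] args) {
        String row = "1,2011-01-01,1,0,1,0,0,6,0,1,0.24,0.2879,0.81,0,3,13,16";
        // prints [1.0, 0.0, 1.0, 0.0, 0.0, 6.0, 0.0, 1.0, 0.24, 0.2879, 0.81, 0.0, 3.0, 13.0]
        System.out.println(Arrays.toString(features(row)));
        // prints 16.0
        System.out.println(label(row));
    }
}
```

Note that with this layout the features include casual and registered, whose sum is cnt in the sample rows (3 + 13 = 16), so the label can be derived from the features.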

 

(2) Code

 

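import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

public class HWDecisionTreeClass {

    // Each line has 17 comma-separated fields (0-based indices 0..16).
    // Indices 2..15 (season through registered) form the feature vector.
    // Index 16 (cnt) is the label. instant and dteday are skipped.
    private static class ParsePoint implements Function<String, LabeledPoint> {
        private static final Pattern COMMA = Pattern.compile(",");

        @Override
        public LabeledPoint call(String line) {
            String[] parts = COMMA.split(line);
            double[] v = new double[parts.length - 3];
            for (int i = 0; i < parts.length - 3; i++) {
                v[i] = Double.parseDouble(parts[i + 2]);
            }
            return new LabeledPoint(Double.parseDouble(parts[16]), Vectors.dense(v));
        }
    }

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf()
                .setAppName("JavaDecisionTreeClassificationExample")
                .setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load and parse the data
        String datapath = "hour.txt";
        JavaRDD<String> lines = jsc.textFile(datapath);
        JavaRDD<LabeledPoint> traindata = lines.map(new ParsePoint());

        // Print the first three parsed points as a sanity check
        List<LabeledPoint> take = traindata.take(3);
        for (LabeledPoint labeledPoint : take) {
            System.out.println("----->" + labeledPoint.features());
            System.out.println("----->" + labeledPoint.label());
        }

        // Use 90% of the data for training and 10% for testing
        JavaRDD<LabeledPoint>[] splits = traindata.randomSplit(new double[] { 0.9, 0.1 });
        // Training data
        JavaRDD<LabeledPoint> trainingData = splits[0];
        // Test data
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set the parameters. An empty categoricalFeaturesInfo means that
        // all features are treated as continuous.
        // cnt is used directly as the class label, and labels must lie in
        // {0, ..., numClasses - 1}, so numClasses must exceed the largest cnt.
        Integer numClasses = 1900;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "gini";
        Integer maxDepth = 20;
        Integer maxBins = 32;

        // Train a DecisionTree model for classification
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Predict with the model and pair each prediction with the actual label
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
                    }
                });
        System.out.println(predictionAndLabel.take(10));

        // Test error = fraction of test points whose prediction differs from the label
        Double testErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Double, Double> pl) {
                return !pl._1().equals(pl._2());
            }
        }).count() / testData.count();

        System.out.println("Test Error: -------------------------------------------------------------------" + testErr);
        System.out.println("Learned classification tree model:\n-------------------------------------------"
                + model.toDebugString());

        // Release Spark resources
        jsc.stop();
    }
}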
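The code above sets impurity to "gini". As a reminder of what that criterion measures, here is a minimal plain-Java sketch of the Gini impurity of a node, computed from per-class sample counts as 1 minus the sum of squared class proportions. `GiniSketch` and `gini` are illustrative names, not MLlib APIs, and this is not a reproduction of how MLlib computes it internally.

```java
public class GiniSketch {
    // Gini impurity of a node given the number of samples in each class:
    // 1 - sum over classes of (count_i / total)^2.
    static double gini(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) {
            total += c;
        }
        if (total == 0) {
            return 0.0; // an empty node is treated as pure in this sketch
        }
        double sumSquares = 0.0;
        for (long c : classCounts) {
            double p = (double) c / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    public static void main(String[] args) {
        // A pure node (all samples in one class) has impurity 0.
        System.out.println(gini(new long[] { 10, 0 })); // prints 0.0
        // An even two-class split has impurity 0.5.
        System.out.println(gini(new long[] { 5, 5 }));  // prints 0.5
    }
}
```

A split is considered good when it produces child nodes whose impurity is lower than that of the parent.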

 

 
