java mllib 算法_決策樹算法原理及Spark MLlib調用實例(Scala/Java/python)

本文介绍了决策树算法的基本原理及其在机器学习中的应用,特别是作为Spark MLlib库的一部分。讨论了决策树的不纯度度量(如基尼不纯度和熵),以及在Spark中实现的参数,如maxBins和maxDepth。文章通过Scala、Java和Python代码示例展示了如何使用Spark训练和评估决策树模型。

決策樹

算法介紹:

決策樹以及其集成算法是機器學習分類和回歸問題中非常流行的算法。因其易解釋性、可處理類別特征、易擴展到多分類問題、不需特征縮放等性質被廣泛使用。樹集成算法如隨機森林以及boosting算法幾乎是解決分類和回歸問題中表現最優的算法。

決策樹是一個貪心算法遞歸地將特征空間划分為兩個部分,在同一個葉子節點的數據最后會擁有同樣的標簽。每次划分通過貪心的以獲得最大信息增益為目的,從可選擇的分裂方式中選擇最佳的分裂節點。節點不純度有節點所含類別的同質性來衡量。工具提供為分類提供兩種不純度衡量(基尼不純度和熵),為回歸提供一種不純度衡量(方差)。

spark.ml支持二分類、多分類以及回歸的決策樹算法,適用於連續特征以及類別特征。另外,對於分類問題,工具可以返回屬於每種類別的概率(類別條件概率),對於回歸問題工具可以返回預測在偏置樣本上的方差。

參數:

checkpointInterval:

類型:整數型。

含義:設置檢查點間隔(>=1),或不設置檢查點(-1)。

featuresCol:

類型:字符串型。

含義:特征列名。

impurity:

類型:字符串型。

含義:計算信息增益的准則(不區分大小寫)。

labelCol:

類型:字符串型。

含義:標簽列名。

maxBins:

類型:整數型。

含義:連續特征離散化的最大數量,以及選擇每個節點分裂特征的方式。

maxDepth:

類型:整數型。

含義:樹的最大深度(>=0)。

minInfoGain:

類型:雙精度型。

含義:分裂節點時所需最小信息增益。

minInstancesPerNode:

類型:整數型。

含義:分裂后自節點最少包含的實例數量。

predictionCol:

類型:字符串型。

含義:預測結果列名。

probabilityCol:

類型:字符串型。

含義:類別條件概率預測結果列名。

rawPredictionCol:

類型:字符串型。

含義:原始預測。

seed:

類型:長整型。

含義:隨機種子。

thresholds:

類型:雙精度數組型。

含義:多分類預測的閥值,以調整預測結果在各個類別的概率。

示例:

下面的例子導入LibSVM格式數據,並將之划分為訓練數據和測試數據。使用第一部分數據進行訓練,剩下數據來測試。訓練之前我們使用了兩種數據預處理方法來對特征進行轉換,並且添加了元數據到DataFrame。

Scala:

import org.apache.spark.ml.Pipeline

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

import org.apache.spark.ml.classification.DecisionTreeClassifier

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load the data stored in LIBSVM format as a DataFrame.

val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.

// Fit on whole dataset to include all labels in index.

val labelIndexer = new StringIndexer()

.setInputCol("label")

.setOutputCol("indexedLabel")

.fit(data)

// Automatically identify categorical features, and index them.

val featureIndexer = new VectorIndexer()

.setInputCol("features")

.setOutputCol("indexedFeatures")

.setMaxCategories(4) // features with > 4 distinct values are treated as continuous.

.fit(data)

// Split the data into training and test sets (30% held out for testing).

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.

val dt = new DecisionTreeClassifier()

.setLabelCol("indexedLabel")

.setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.

val labelConverter = new IndexToString()

.setInputCol("prediction")

.setOutputCol("predictedLabel")

.setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.

val pipeline = new Pipeline()

.setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.

val model = pipeline.fit(trainingData)

// Make predictions.

val predictions = model.transform(testData)

// Select example rows to display.

predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.

val evaluator = new MulticlassClassificationEvaluator()

.setLabelCol("indexedLabel")

.setPredictionCol("prediction")

.setMetricName("accuracy")

val accuracy = evaluator.evaluate(predictions)

println("Test Error = " + (1.0 - accuracy))

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]

println("Learned classification tree model:\n" + treeModel.toDebugString)

Java:

import org.apache.spark.ml.Pipeline;

import org.apache.spark.ml.PipelineModel;

import org.apache.spark.ml.PipelineStage;

import org.apache.spark.ml.classification.DecisionTreeClassifier;

import org.apache.spark.ml.classification.DecisionTreeClassificationModel;

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;

import org.apache.spark.ml.feature.*;

import org.apache.spark.sql.Dataset;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.SparkSession;

// Load the data stored in LIBSVM format as a DataFrame.

Dataset data = spark

.read()

.format("libsvm")

.load("data/mllib/sample_libsvm_data.txt");

// Index labels, adding metadata to the label column.

// Fit on whole dataset to include all labels in index.

StringIndexerModel labelIndexer = new StringIndexer()

.setInputCol("label")

.setOutputCol("indexedLabel")

.fit(data);

// Automatically identify categorical features, and index them.

VectorIndexerModel featureIndexer = new VectorIndexer()

.setInputCol("features")

.setOutputCol("indexedFeatures")

.setMaxCategories(4) // features with > 4 distinct values are treated as continuous.

.fit(data);

// Split the data into training and test sets (30% held out for testing).

Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3});

Dataset trainingData = splits[0];

Dataset testData = splits[1];

// Train a DecisionTree model.

DecisionTreeClassifier dt = new DecisionTreeClassifier()

.setLabelCol("indexedLabel")

.setFeaturesCol("indexedFeatures");

// Convert indexed labels back to original labels.

IndexToString labelConverter = new IndexToString()

.setInputCol("prediction")

.setOutputCol("predictedLabel")

.setLabels(labelIndexer.labels());

// Chain indexers and tree in a Pipeline.

Pipeline pipeline = new Pipeline()

.setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

// Train model. This also runs the indexers.

PipelineModel model = pipeline.fit(trainingData);

// Make predictions.

Dataset predictions = model.transform(testData);

// Select example rows to display.

predictions.select("predictedLabel", "label", "features").show(5);

// Select (prediction, true label) and compute test error.

MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()

.setLabelCol("indexedLabel")

.setPredictionCol("prediction")

.setMetricName("accuracy");

double accuracy = evaluator.evaluate(predictions);

System.out.println("Test Error = " + (1.0 - accuracy));

DecisionTreeClassificationModel treeModel =

(DecisionTreeClassificationModel) (model.stages()[2]);

System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());

Python:

from pyspark.ml import Pipeline

from pyspark.ml.classification import DecisionTreeClassifier

from pyspark.ml.feature import StringIndexer, VectorIndexer

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Load the data stored in LIBSVM format as a DataFrame.

data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

# Index labels, adding metadata to the label column.

# Fit on whole dataset to include all labels in index.

labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)

# Automatically identify categorical features, and index them.

# We specify maxCategories so features with > 4 distinct values are treated as continuous.

featureIndexer =\

VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)

# Split the data into training and test sets (30% held out for testing)

(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.

dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")

# Chain indexers and tree in a Pipeline

pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])

# Train model. This also runs the indexers.

model = pipeline.fit(trainingData)

# Make predictions.

predictions = model.transform(testData)

# Select example rows to display.

predictions.select("prediction", "indexedLabel", "features").show(5)

# Select (prediction, true label) and compute test error

evaluator = MulticlassClassificationEvaluator(

labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")

accuracy = evaluator.evaluate(predictions)

print("Test Error = %g " % (1.0 - accuracy))

treeModel = model.stages[2]

# summary only

print(treeModel)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值