java决策树_java编写ID3决策树

这篇博客介绍了如何用Java实现ID3决策树算法。文章提供了Data和DecisionTree两个核心类的源码,用于构建和操作决策树。Data类表示样本数据,DecisionTree类包含构建决策树的主要方法,如createTree()、chooseBestFeatureToSplit()和calcShannonEnt()。还展示了如何构造测试数据集并生成决策树。

说明:每个样本都会装入Data样本对象,决策树生成算法接收的是一个Array样本列表,所以构建测试数据时也要符合格式,最后生成的决策树是树的根节点,通过里面提供的showTree()方法可查看整个树结构,下面奉上源码。

Data.java

package ai.tree.data;

import java.util.HashMap;

/**

* 样本类

* @author ChenLuyang

* @date 2019/2/21

*/

public class Data implements Cloneable{

/**

* K是特征描述,V是特征值

*/

private HashMap feature = new HashMap();

/**

* 该样本结论

*/

private String result;

public Data(HashMap feature,String result){

this.feature = feature;

this.result = result;

}

public HashMap getFeature() {

return feature;

}

public String getResult() {

return result;

}

private void setFeature(HashMap feature) {

this.feature = feature;

}

@Override

public Data clone()

{

Data object=null;

try {

object = (Data) super.clone();

object.setFeature((HashMap) this.feature.clone());

} catch (CloneNotSupportedException e) {

e.printStackTrace();

}

return object;

}

}

DecisionTree.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.math.BigDecimal;

import java.util.*;

/**

* @author ChenLuyang

* @date 2019/2/21

*/

public class DecisionTree {

/**

* 递归构建决策树

*

* @param dataList 样本集合

* @return ai.tree.algorithm.DecisionTree.TreeNode 使用传入样本构建的决策节点

* @author ChenLuyang

* @date 2019/2/21 16:05

*/

public TreeNode createTree(List dataList) {

//创建当前节点

TreeNode nowTreeNode = new TreeNode();

//当前节点的各个分支节点

Map featureDecisionMap = new HashMap();

//统计当前样本集中所有的分类结果

Set resultSet = new HashSet();

for (Data data :

dataList) {

resultSet.add(data.getResult());

}

//如果当前样本集只有一种类别,则表示不用分类了,返回当前节点

if (resultSet.size() == 1) {

String resultClassify = resultSet.iterator().next();

nowTreeNode.setResultNode(resultClassify);

return nowTreeNode;

}

//如果数据集中特征为空,则选择整个集合中出现次数最多的分类,作为分类结果

if (dataList.get(0).getFeature().size() == 0) {

Map countMap = new HashMap();

for (Data data :

dataList) {

Integer num = countMap.get(data.getResult());

if (num == null) {

countMap.put(data.getResult(), 1);

} else {

countMap.put(data.getResult(), num + 1);

}

}

String tmpResult = "";

Integer tmpNum = 0;

for (String res :

countMap.keySet()) {

if (countMap.get(res) > tmpNum) {

tmpNum = countMap.get(res);

tmpResult = res;

}

}

nowTreeNode.setResultNode(tmpResult);

return nowTreeNode;

}

//寻找当前最优分类

String bestLabel = chooseBestFeatureToSplit(dataList);

//提取最优特征的所有可能值

Set bestLabelInfoSet = new HashSet();

for (Data data :

dataList) {

bestLabelInfoSet.add(data.getFeature().get(bestLabel));

}

//使用最优特征的各个特征值进行分类

for (String labelInfo :

bestLabelInfoSet) {

for (Data data :

dataList) {

}

List branchDataList = splitDataList(dataList, bestLabel, labelInfo);

//最优特征下该特征值的节点

TreeNode branchTreeNode = createTree(branchDataList);

featureDecisionMap.put(labelInfo, branchTreeNode);

}

nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap);

return nowTreeNode;

}

/**

* 计算传入数据集中的最优分类特征

*

* @param dataList

* @return int 最优分类特征的描述

* @author ChenLuyang

* @date 2019/2/21 14:12

*/

public String chooseBestFeatureToSplit(List dataList) {

//目前数据集中的特征集合

Set futureSet = dataList.get(0).getFeature().keySet();

//未分类时的熵

BigDecimal baseEntropy = calcShannonEnt(dataList);

//熵差

BigDecimal bestInfoGain = new BigDecimal("0");

//最优特征

String bestFeature = "";

//按照各特征分类

for (String future :

futureSet) {

//该特征分类后的熵

BigDecimal futureEntropy = new BigDecimal("0");

//该特征的所有特征值去重集合

Set futureInfoSet = new HashSet();

for (Data data :

dataList) {

futureInfoSet.add(data.getFeature().get(future));

}

//按照该特征的特征值一一分类

for (String futureInfo :

futureInfoSet) {

List splitResultDataList = splitDataList(dataList, future, futureInfo);

//分类后样本数占总样本数的比例

BigDecimal tmpProb = new BigDecimal(splitResultDataList.size() + "").divide(new BigDecimal(dataList.size() + ""), 5, BigDecimal.ROUND_HALF_DOWN);

//所占比例乘以分类后的样本熵,然后再进行熵的累加

futureEntropy = futureEntropy.add(tmpProb.multiply(calcShannonEnt(splitResultDataList)));

}

BigDecimal subEntropy = baseEntropy.subtract(futureEntropy);

if (subEntropy.compareTo(bestInfoGain) >= 0) {

bestInfoGain = subEntropy;

bestFeature = future;

}

}

return bestFeature;

}

/**

* 计算传入样本集的熵值

*

* @param dataList 样本集

* @return java.math.BigDecimal 熵

* @author ChenLuyang

* @date 2019/2/22 9:41

*/

public BigDecimal calcShannonEnt(List dataList) {

//样本总数

BigDecimal sumEntries = new BigDecimal(dataList.size() + "");

//香农熵

BigDecimal shannonEnt = new BigDecimal("0");

//统计各个分类结果的样本数量

Map resultCountMap = new HashMap();

for (Data data :

dataList) {

Integer dataResultCount = resultCountMap.get(data.getResult());

if (dataResultCount == null) {

resultCountMap.put(data.getResult(), 1);

} else {

resultCountMap.put(data.getResult(), dataResultCount + 1);

}

}

for (String resultCountKey :

resultCountMap.keySet()) {

BigDecimal resultCountValue = new BigDecimal(resultCountMap.get(resultCountKey).toString());

BigDecimal prob = resultCountValue.divide(sumEntries, 5, BigDecimal.ROUND_HALF_DOWN);

shannonEnt = shannonEnt.subtract(prob.multiply(new BigDecimal(Math.log(prob.doubleValue()) / Math.log(2) + "")));

}

return shannonEnt;

}

/**

* 根据某个特征的特征值,进行样本数据的划分,将划分后的样本数据集返回

*

* @param dataList 待划分的样本数据集

* @param future 筛选的特征依据

* @param info 筛选的特征值依据

* @return java.util.List 按照指定特征值分类后的数据集

* @author ChenLuyang

* @date 2019/2/21 18:26

*/

public List splitDataList(List dataList, String future, String info) {

List resultDataList = new ArrayList();

for (Data data :

dataList) {

if (data.getFeature().get(future).equals(info)) {

Data newData = (Data) data.clone();

newData.getFeature().remove(future);

resultDataList.add(newData);

}

}

return resultDataList;

}

/**

* L:每一个特征的描述信息的类型

* F:特征的类型

* R:最终分类结果的类型

*/

public class TreeNode {

/**

* 该节点的最优特征的描述信息

*/

private L label;

/**

* 根据不同的特征作出响应的决定。

* K为特征值,V为该特征值作出的决策节点

*/

private Map featureDecisionMap;

/**

* 是否为最终分类节点

*/

private boolean isFinal;

/**

* 最终分类结果信息

*/

private R resultClassify;

/**

* 设置叶子节点

*

* @param resultClassify 最终分类结果

* @return void

* @author ChenLuyang

* @date 2019/2/22 18:31

*/

public void setResultNode(R resultClassify) {

this.isFinal = true;

this.resultClassify = resultClassify;

}

/**

* 设置分支节点

*

* @param label 当前分支节点的描述信息(特征)

* @param featureDecisionMap 当前分支节点的各个特征值,与其对应的子节点

* @return void

* @author ChenLuyang

* @date 2019/2/22 18:31

*/

public void setDecisionNode(L label, Map featureDecisionMap) {

this.isFinal = false;

this.label = label;

this.featureDecisionMap = featureDecisionMap;

}

/**

* 展示当前节点的树结构

*

* @return void

* @author ChenLuyang

* @date 2019/2/22 16:54

*/

public String showTree() {

HashMap treeMap = new HashMap();

if (isFinal) {

String key = "result";

R value = resultClassify;

treeMap.put(key, value.toString());

} else {

String key = label.toString();

HashMap showFutureMap = new HashMap();

for (F f :

featureDecisionMap.keySet()) {

showFutureMap.put(f, featureDecisionMap.get(f).showTree());

}

String value = showFutureMap.toString();

treeMap.put(key, value);

}

return treeMap.toString();

}

public L getLabel() {

return label;

}

public Map getFeatureDecisionMap() {

return featureDecisionMap;

}

public R getResultClassify() {

return resultClassify;

}

public boolean getFinal() {

return isFinal;

}

}

}

Start.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.util.ArrayList;

import java.util.HashMap;

import java.util.List;

/**

* @author ChenLuyang

* @date 2019/2/22

*/

public class Start {

/**

* 构建测试样本集,测试样本如下:

样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男

样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女

样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:女

样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=没眼镜} 分类:男

样本特征:{头发长短=短发, 身材=瘦, 是否戴眼镜=没眼镜} 分类:男

样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女

样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男

* @author ChenLuyang

* @date 2019/2/21 15:34

* @return java.util.List 样本集

*/

public static List createDataList(){

/**

* 样本特征描述

* @author ChenLuyang

* @date 2019/2/22 18:55

* @return java.util.List

*/

String[] labels = new String[]{"是否戴眼镜", "头发长短", "身材"};

List dataList = new ArrayList();

HashMap feature1 = new HashMap();

feature1.put(labels[0],"有眼镜");

feature1.put(labels[1].toString(),"短发");

feature1.put(labels[2].toString(),"胖");

dataList.add(new Data(feature1,"男"));

HashMap feature2 = new HashMap();

feature2.put(labels[0],"有眼镜");

feature2.put(labels[1],"长发");

feature2.put(labels[2],"瘦");

dataList.add(new Data(feature2,"女"));

HashMap feature3 = new HashMap();

feature3.put(labels[0],"有眼镜");

feature3.put(labels[1],"短发");

feature3.put(labels[2],"胖");

dataList.add(new Data(feature3,"女"));

HashMap feature4 = new HashMap();

feature4.put(labels[0],"没眼镜");

feature4.put(labels[1],"长发");

feature4.put(labels[2],"胖");

dataList.add(new Data(feature4,"男"));

HashMap feature5 = new HashMap();

feature5.put(labels[0],"没眼镜");

feature5.put(labels[1],"短发");

feature5.put(labels[2],"瘦");

dataList.add(new Data(feature5,"男"));

HashMap feature6 = new HashMap();

feature6.put(labels[0],"有眼镜");

feature6.put(labels[1],"长发");

feature6.put(labels[2],"瘦");

dataList.add(new Data(feature6,"女"));

HashMap feature7 = new HashMap();

feature7.put(labels[0],"有眼镜");

feature7.put(labels[1],"长发");

feature7.put(labels[2],"胖");

dataList.add(new Data(feature7,"男"));

return dataList;

}

public static void main(String[] args) {

DecisionTree decisionTree = new DecisionTree();

//使用测试样本生成决策树

DecisionTree.TreeNode tree = decisionTree.createTree(createDataList());

//展示决策树

System.out.println(tree.showTree());

}

}

生成树结构:{是否戴眼镜={没眼镜={result=男}, 有眼镜={身材={胖={头发长短={长发={result=男}, 短发={result=女}}}, 瘦={result=女}}}}}

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符  | 博主筛选后可见
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值