Machine Learning in Action: A Java Implementation of the Decision Tree (ID3) Algorithm

This post shares a Java implementation of the decision tree (ID3) algorithm from machine learning. It covers binary feature discrimination, loading a training dataset from a directory tree, recall-rate (accuracy) evaluation, decision tree construction, and storing/loading the tree as JSON. A Python version and other machine learning algorithms are available on request.


Without further ado, straight to the code. If it helps you, please leave a like.
For a Python version, or other machine learning algorithms, email: 476562571@qq.com


Main functionality:
Binary feature discrimination
Recursive traversal of a directory tree to load the training dataset
Recall-rate evaluation on a test set (implemented as overall accuracy in recallRate)
Decision tree construction (ID3, information gain)
Decision tree storage (written to a JSON file); requires Alibaba fastjson (fastjson-1.2.7.jar)
Decision tree loading (read from a JSON file); requires Alibaba fastjson (fastjson-1.2.7.jar); file I/O additionally uses Apache commons-io (FileUtils)

A minimal usage sketch follows this list, before the full source.
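The sketch below is not part of the original listing; it is a hypothetical driver class (the name Id3QuickStart is made up) showing how the class further down can be exercised end to end on its built-in toy dataset, assuming DecisionTreeID3 and its dependencies (fastjson, commons-io, slf4j, the Metion config class) are on the classpath and in scope:

import java.util.Arrays;
import java.util.List;

// Hypothetical driver; place it in (or import) com.code.ku.qa.metion.classifier.
public class Id3QuickStart {
    public static void main(String[] args) {
        DecisionTreeID3 id3 = new DecisionTreeID3();

        // Toy dataset: rows of [surfacing, flippers, class]; labels = ["surfacing", "flippers"].
        List<List<String>> dataset = id3.createDataset();
        List<String> labels = id3.createLabels();

        DecisionTreeID3.TreeNode tree = id3.createTree(dataset, labels);
        id3.print(tree, "->");   // dump the learned tree structure

        // On the toy data, surfacing=1 and flippers=1 should be classified as "yes".
        System.out.println(DecisionTreeID3.classify(tree, labels, Arrays.asList("1", "1")));
    }
}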

package com.code.ku.qa.metion.classifier;

import com.alibaba.fastjson.JSONObject;
import com.code.ku.qa.metion.Metion;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;

/**
 * @Date: 2018/11/15
 * @Time: 19:08
 * @User: Likf
 * @Description: ID3 decision tree: training, classification, evaluation, and JSON persistence
 */
public class DecisionTreeID3 {


    /** Logger */
    private static final Logger _LOG = LoggerFactory.getLogger(DecisionTreeID3.class);

    public static TreeNode tree = null;

    static{
        // Pre-load a previously stored tree from the configured JSON file (null if it cannot be loaded).
        tree = loadTreeFromJsonFile(Metion.Config.getPath("classify\\id3\\tree.json"));
    }

    public DecisionTreeID3() {

    }

    public static String classify(List<String> labels,List<String> testData){
        return classify(tree,labels,testData);
    }



    /**
     * Compute the Shannon entropy of the dataset: H(D) = -sum_i(p_i * log2(p_i)),
     * where p_i is the proportion of rows whose class label (last column) is i.
     * @param dataset list of feature rows, each ending with the class label
     * @return the Shannon entropy
     */
    public double calChannonEnt(List<List<String>> dataset){

        Map<String,Double> outLabels = new HashMap<>();
        for (List<String> fetures:dataset){
            String outLabel = fetures.get(fetures.size()-1);
            if(!outLabels.keySet().contains(outLabel)){
                outLabels.put(outLabel,0.0);
            }
            outLabels.put(outLabel,outLabels.get(outLabel)+1);
        }

        double channonEnt = 0.0;

        for(Map.Entry<String,Double> entry:outLabels.entrySet()){
            double pi = entry.getValue()/dataset.size();
            channonEnt -= pi*(Math.log(pi)/Math.log(2.0));
        }
        return channonEnt;

    }


    /**
     * Split the dataset: keep the rows whose feature at fetureIndex equals value,
     * with that feature column removed from the returned rows.
     * @param dataset     the full dataset
     * @param fetureIndex index of the feature to split on
     * @param value       feature value to match
     * @return the reduced sub-dataset
     */
    private List<List<String>> splitDataSet(List<List<String>> dataset,int fetureIndex,String value){

        List<List<String>> subDataSet = new ArrayList<>();


        for(List<String> fetures:dataset){
            try {
                if(fetures.get(fetureIndex).equals(value)){
                    List<String> reduceFetures = new LinkedList<>();
                    reduceFetures.addAll(fetures.subList(0,fetureIndex));
                    reduceFetures.addAll(fetures.subList(fetureIndex+1,fetures.size()));
                    subDataSet.add(reduceFetures);
                }
            } catch (Exception e) {
                _LOG.trace("异常特征:"+fetures);
            }
        }
        return subDataSet;
    }


    /**
     * Choose the feature with the largest information gain:
     * gain(i) = H(D) - sum over values v of feature i of (|D_v|/|D|) * H(D_v).
     * @param dataSet the dataset (last column is the class label)
     * @return index of the best feature, or -1 if no feature yields positive gain
     */
    private int chooseBestFetureToSplit(List<List<String>> dataSet){


        int numFetures = dataSet.get(0).size()-1;
        double baseEntropy = calChannonEnt(dataSet);
        double bestInfoGain = 0.0;
        int bestFeture = -1;

        for (int i = 0; i < numFetures; i++) {
            double infoGain = 0.0;
            double newEntropy = 0.0;
            Set<String> featureVals = getFetureVals(dataSet,i);
            for(String fetureVal:featureVals){
                List<List<String>> subDataSet = splitDataSet(dataSet,i,fetureVal);
                double prob = (double)subDataSet.size()/dataSet.size();
                newEntropy+=prob*calChannonEnt(subDataSet);
            }
            infoGain=baseEntropy-newEntropy;
            if(infoGain>bestInfoGain){
                bestInfoGain = infoGain;
                bestFeture = i;
            }
        }
        return bestFeture;
    }

    /**
     * Majority vote: return the most frequent class label.
     * @param classifyList the class labels to vote over
     * @return the majority class label
     */
    public String majorityCnt(List<String> classifyList){
        Map<String,Double> classCount = new HashMap<>();

        // Count occurrences of each class label.
        classifyList.forEach(
                classify -> classCount.merge(classify, 1.0, Double::sum)
        );

        double max = 0.0;
        String key = null;
        for(Map.Entry<String,Double> entry:classCount.entrySet()){
            if(entry.getValue()>max){
                max = entry.getValue();
                key = entry.getKey();
            }
        }

        return key;

    }

    /**
     * Recursively build the decision tree. Recursion stops when all samples share one
     * class label, or when no features are left (majority vote then decides the leaf).
     * @param dataset training data, last column is the class label
     * @param labels  feature names, in column order
     * @return the root node of the (sub)tree
     */
    public TreeNode createTree(List<List<String>> dataset,List<String> labels){

        List<String> classifyList = getFetureLists(dataset,dataset.get(0).size()-1);
        Set<String> classifySet = new HashSet<>(classifyList);
        if(classifySet.size() == 1){
            return new TreeNode(classifyList.get(0));
        }

        if(dataset.get(0).size() == 1){
            return new TreeNode(majorityCnt(classifyList));
        }
        int bestFeat = chooseBestFetureToSplit(dataset);

        // If no feature yields positive information gain, fall back to the last feature.
        if(bestFeat == -1){
            bestFeat = labels.size()-1;
        }
        String bestFeatLabel = labels.get(bestFeat);
        TreeNode tree = new TreeNode(bestFeatLabel);

        // Feature labels that remain after removing the chosen one.
        List<String> subLabels = new ArrayList<>();
        for(int i=0;i<labels.size();i++){
            if(i!=bestFeat){
                subLabels.add(labels.get(i));
            }
        }

        // One child subtree per distinct value of the chosen feature.
        Set<String> fetureVals = getFetureVals(dataset,bestFeat);
        for(String fv:fetureVals){
            TreeNode node = tree.addChild(createTree(splitDataSet(dataset,bestFeat,fv),subLabels));
            node.setLabel(fv);
            node.setValue(fv);
        }
        return tree;
    }


    /**
     * Classify a new sample by walking the tree from the root.
     * @param tree       the decision tree
     * @param featLabels feature names, in the same order as testData
     * @param testData   feature values of the sample to classify
     * @return the predicted class label, or null if a feature value is not covered by the tree
     */
    public static String classify(TreeNode tree,List<String> featLabels,List<String> testData){

        Map<String,Integer> featLabelMap = new HashMap<>();
        for (int i = 0; i <featLabels.size() ; i++) {
            featLabelMap.put(featLabels.get(i),i);
        }

        int featIndex = featLabelMap.get(tree.name);
        String classifyLabel = null;
        for(TreeNode node:tree.childs){
            if(node.value.equals(testData.get(featIndex))){
                if(node.childs.isEmpty()){
                    classifyLabel = node.name;
                }else{
                    classifyLabel = classify(node,featLabels,testData);
                }
            }
        }
        if(classifyLabel == null){
            // TODO handle feature values that never appear in the tree
        }
        return classifyLabel;

    }

    /**
     * Evaluate the tree on a labeled test set: the fraction of samples whose predicted
     * class matches the true class (i.e. the overall accuracy).
     * @param tree        the decision tree
     * @param labels      feature names, in column order
     * @param testDataSet test data, last column is the true class label
     * @return the proportion of correctly classified samples
     */
    public double recallRate(TreeNode tree,List<String> labels,List<List<String>> testDataSet){
        int flagPos = testDataSet.get(0).size()-1;
        double count = 0.0;
        for(List<String> fetures:testDataSet){

            String realVal = fetures.get(flagPos);
            String preVal = classify(tree,labels,fetures.subList(0,flagPos));
            if(realVal.equals(preVal)){
                count++;
            }
        }

        return count/testDataSet.size();
    }

    /**
     * Get the distinct values of feature i across the dataset.
     * @param dataSet the dataset
     * @param i       feature column index
     * @return the distinct feature values, in first-seen order
     */
    private Set<String> getFetureVals(List<List<String>> dataSet, int i) {
        Set<String> vals = new LinkedHashSet<>(getFetureLists(dataSet,i));
        return vals;
    }

    /**
     * Get the value of feature i for every row of the dataset.
     * @param dataSet the dataset
     * @param i       feature column index
     * @return the list of feature values
     */
    private List<String> getFetureLists(List<List<String>> dataSet, int i) {
        List<String> vals = new LinkedList<>();
        for(List<String> fetures:dataSet){
            try {
                vals.add(fetures.get(i));
            } catch (IndexOutOfBoundsException e) {
                _LOG.error("Malformed feature row: " + fetures);
            }
        }
        return vals;
    }

    /**
     * Build a decision tree from the training data and store it as a JSON file
     * (requires Alibaba fastjson on the classpath).
     * @param trainDirPath directory containing the comma-separated training files
     * @param treeJsonPath path of the JSON file the tree is written to
     * @param labels       feature names, in column order
     */
    public void storeDTreeToJson(String trainDirPath,String treeJsonPath,List<String> labels){

        List<List<String>> trainDataSet = loadDataSet(trainDirPath);
        DecisionTreeID3.TreeNode tree = createTree(trainDataSet,labels);
        File treeFile = new File(treeJsonPath);
        try {
            if(!treeFile.exists()){
                treeFile.createNewFile();
            }
            FileUtils.writeStringToFile(treeFile,JSONObject.toJSONString(tree),"utf-8");
        } catch (IOException e) {
            _LOG.error("Failed to write decision tree to " + treeJsonPath, e);
        }
    }

    /**
     * Load a decision tree from a JSON file (requires Alibaba fastjson on the classpath).
     * @param treePath path of the JSON file
     * @return the deserialized tree, or null if loading fails
     */
    public static TreeNode loadTreeFromJsonFile(String treePath){

        try {
            File treeFile = new File(treePath);
            String jsonText = FileUtils.readFileToString(treeFile,"utf-8");
            return JSONObject.parseObject(jsonText,DecisionTreeID3.TreeNode.class);
        } catch (IOException e) {
            _LOG.error("Failed to load decision tree from " + treePath, e);
            return null;
        }
    }

    /**
     * Toy dataset for quick testing: two binary features plus a yes/no class label.
     * @return the toy dataset
     */
    public List<List<String>> createDataset(){
        List<List<String>> dataset = new LinkedList<>();
        dataset.add(Arrays.asList("1","1","yes"));
        dataset.add(Arrays.asList("1","1","yes"));
        dataset.add(Arrays.asList("1","0","no"));
        dataset.add(Arrays.asList("0","1","no"));
        dataset.add(Arrays.asList("0","1","no"));
        return dataset;
    }

    public List<String> createLabels(){
        List<String> labels = new ArrayList<>();
        labels.add("surfacing");
        labels.add("flippers");
        return labels;
    }

    /**
     * Print the tree structure, doubling the prefix at each level.
     * @param tree the (sub)tree to print
     * @param root prefix string for the current level
     */
    public void print(TreeNode tree,String root){

        System.out.println(root+" "+tree.value+":"+tree.name);
        if(tree.childs==null || tree.childs.isEmpty()){
            return;
        }

        for(TreeNode node:tree.childs){
            print(node,root+root);
        }
    }

    /**
     * Tree node: for an internal node, name is the feature label used for the split;
     * for a leaf, name is the class label. value/label hold the feature value on the
     * edge from the parent node.
     */
    public static class TreeNode implements Serializable {

        private String name;

        private String label;


        private List<TreeNode> childs = new ArrayList<>();

        private String value;

        public TreeNode(){

        }
        public TreeNode(String name) {
            this.name = name;
        }



        public TreeNode addChild(TreeNode node){
            childs.add(node);
            return node;
        }
        public TreeNode addChild(String name){
            return addChild(new TreeNode(name));
        }

        public String getValue() {
            return value;
        }

        public void setValue(String value) {
            this.value = value;
        }

        @Override
        public String toString() {
            return name;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }



        public List<TreeNode> getChilds() {
            return childs;
        }

        public void setChilds(List<TreeNode> childs) {
            this.childs = childs;
        }

        public String getLabel() {
            return label;
        }

        public void setLabel(String label) {
            this.label = label;
        }
    }

    /**
     * Load the training data: every file under dirPath is read line by line; each line
     * holds comma-separated feature values with the class label as the last field.
     * @param dirPath root directory of the training files
     * @return the dataset as a list of feature rows
     */
    public List<List<String>> loadDataSet(String dirPath){

        File dir = new File(dirPath);
        List<File> dsFiles = new ArrayList<>();
        travelDir(dir,dsFiles);
        List<List<String>> dataset = new ArrayList<>();
        for(File file:dsFiles){
            try {
                List<String> lines = FileUtils.readLines(file,"utf-8");
                for(String line:lines){
                    String[] lineArr = line.split(",");
                    dataset.add(Arrays.asList(lineArr));
                }
            } catch (IOException e) {
                _LOG.error("Failed to read training file: " + file, e);
            }
        }
        return dataset;
    }

    /**
     * Recursively collect all regular files under a directory.
     * @param dir     the directory to traverse
     * @param dsFiles accumulator for the files found so far
     * @return dsFiles, with every file under dir added
     */
    public List<File> travelDir(File dir,List<File> dsFiles){
        File[] files = dir.listFiles();
        if(files == null){   // dir is not a directory or cannot be read
            return dsFiles;
        }
        for(File file : files){
            if(file.isDirectory()){
                travelDir(file,dsFiles);
            }else{
                dsFiles.add(file);
            }
        }
        return dsFiles;
    }
    public static void main(String[] args) {
        DecisionTreeID3 id3 = new DecisionTreeID3();
        List<List<String>> dataset = id3.createDataset();

        _LOG.trace("香农熵:"+id3.calChannonEnt(dataset));

        _LOG.trace("切分:"+id3.splitDataSet(dataset,0,"1"));

        _LOG.trace("选择最好的特征分类:"+id3.chooseBestFetureToSplit(dataset));


        TreeNode tree = id3.createTree(dataset,id3.createLabels());

        id3.print(tree,"->");

        List<String> testData1 = Arrays.asList("1","0");
        List<String> testData2 = Arrays.asList("1","1");
        List<String> testData3 = Arrays.asList("0","1");
        System.out.println(testData1+":"+DecisionTreeID3.classify(tree,id3.createLabels(),testData1));
        System.out.println(testData2+":"+DecisionTreeID3.classify(tree,id3.createLabels(),testData2));
        System.out.println(testData3+":"+DecisionTreeID3.classify(tree,id3.createLabels(),testData3));

        String dir = "D:\\workspace\\zl\\pre\\mist-parent\\mist-kbqa\\datas\\metion\\train\\train-10";
        System.out.println(id3.loadDataSet(dir));
    }
}
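For the JSON persistence helpers, a workflow along these lines should work; it is a hypothetical sketch (the class name Id3PersistenceDemo, the directory, and the file path are placeholder examples, not values from the original post), and it assumes the training directory contains the comma-separated files expected by loadDataSet:

import java.util.Arrays;
import java.util.List;

// Hypothetical driver; place it in (or import) com.code.ku.qa.metion.classifier.
public class Id3PersistenceDemo {
    public static void main(String[] args) {
        // Hypothetical paths; adjust to your project layout.
        String trainDir = "datas/train";             // directory of comma-separated training files
        String treeJson = "classify/id3/tree.json";  // JSON file the tree is persisted to

        DecisionTreeID3 id3 = new DecisionTreeID3();
        List<String> labels = id3.createLabels();    // replace with your real feature names

        id3.storeDTreeToJson(trainDir, treeJson, labels);   // train from trainDir, write the tree as JSON
        DecisionTreeID3.TreeNode tree = DecisionTreeID3.loadTreeFromJsonFile(treeJson);
        System.out.println(DecisionTreeID3.classify(tree, labels, Arrays.asList("1", "0")));
    }
}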
