话不多说,直接上代码,若有帮助,帮忙点赞哦
python版,或其他机器学习算法,可发邮箱:476562571@qq.com
主要实现功能:
特征 二值判别
递归遍历文件目录加载训练数据集
召回率计算
决策树构建
决策树存储(存储json文件)需要依赖 com.alibaba fastjson-1.2.7.jar
决策树读取(读取json文件)需要依赖 com.alibaba fastjson-1.2.7.jar
package com.code.ku.qa.metion.classifier;
import com.alibaba.fastjson.JSONObject;
import com.code.ku.qa.metion.Metion;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
/**
* @Date: 2018/11/15
* @Time: 19:08
* @User: Likf
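* @Description: ID3 decision tree over string-valued samples: builds the tree, classifies new samples, and stores/loads the tree as a JSON file.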
*/
public class DecisionTreeID3 {
/** Class-wide SLF4J logger. */
private static final Logger _LOG = LoggerFactory.getLogger(DecisionTreeID3.class);
/**
 * Tree used by the static {@code classify(labels, testData)} shortcut. It is read from the
 * configured JSON file once, while the class is being initialised, and stays {@code null}
 * when that read fails with an I/O error.
 */
public static TreeNode tree = loadTreeFromJsonFile(Metion.Config.getPath("classify\\id3\\tree.json"));
/** Creates a classifier instance; the class declares no instance fields. */
public DecisionTreeID3() {
    super();
}
/**
 * Classifies one sample with the shared, class-level {@code tree}.
 *
 * @param labels   feature names, positioned like the values in {@code testData}
 * @param testData feature values of the sample to classify
 * @return the result of the three-argument {@code classify} for the shared tree
 */
public static String classify(List<String> labels,List<String> testData){
    final TreeNode sharedTree = tree;
    return classify(sharedTree, labels, testData);
}
/**
 * Computes the Shannon entropy (base 2) of the class column of a dataset.
 * The class of a row is its last element.
 *
 * @param dataset rows of string values; the last value of every row is its class
 * @return the entropy of the class distribution, {@code 0.0} for an empty dataset
 */
public double calChannonEnt(List<List<String>> dataset){
    Map<String,Double> classCounts = new HashMap<>();
    for (List<String> row : dataset) {
        String cls = row.get(row.size() - 1);
        Double seen = classCounts.get(cls);
        classCounts.put(cls, seen == null ? 1.0 : seen + 1);
    }
    final int total = dataset.size();
    final double log2 = Math.log(2.0);
    double entropy = 0.0;
    for (double count : classCounts.values()) {
        double p = count / total;
        // log base 2 via change of base
        entropy -= p * (Math.log(p) / log2);
    }
    return entropy;
}
/**
 * Selects the rows whose feature at {@code fetureIndex} equals {@code value} and returns
 * copies of them with that feature column removed.
 *
 * <p>Rows that cannot be tested (a {@code null} row, a row too short for
 * {@code fetureIndex}, or a {@code null} value at that position) are reported at trace
 * level and left out of the result, so one malformed sample does not abort the split.
 *
 * @param dataset     rows to filter; the rows themselves are not modified
 * @param fetureIndex zero-based column to test and then drop
 * @param value       feature value a row must have to be kept
 * @return a new list of reduced rows, in the order they appear in {@code dataset}
 */
private List<List<String>> splitDataSet(List<List<String>> dataset,int fetureIndex,String value){
    List<List<String>> subDataSet = new ArrayList<>();
    for (List<String> fetures : dataset) {
        // Validate explicitly instead of relying on a caught exception to skip bad rows.
        boolean usable = fetures != null
                && fetureIndex >= 0
                && fetureIndex < fetures.size()
                && fetures.get(fetureIndex) != null;
        if (!usable) {
            _LOG.trace("异常特征:{}", fetures);
            continue;
        }
        if (fetures.get(fetureIndex).equals(value)) {
            // ArrayList: the reduced rows are read by index in the later stages.
            List<String> reduceFetures = new ArrayList<>(fetures.size() - 1);
            reduceFetures.addAll(fetures.subList(0, fetureIndex));
            reduceFetures.addAll(fetures.subList(fetureIndex + 1, fetures.size()));
            subDataSet.add(reduceFetures);
        }
    }
    return subDataSet;
}
/**
 * Picks the feature whose split yields the largest information gain.
 *
 * @param dataSet rows whose last element is the class; the first row fixes the feature count
 * @return zero-based index of the best feature, or {@code -1} when no feature has a gain
 *         greater than zero
 */
private int chooseBestFetureToSplit(List<List<String>> dataSet){
    final int featureCount = dataSet.get(0).size() - 1;
    final double entropyBefore = calChannonEnt(dataSet);
    double topGain = 0.0;
    int topIndex = -1;
    int candidate = 0;
    while (candidate < featureCount) {
        // Weighted entropy of the partitions produced by splitting on this feature.
        double entropyAfter = 0.0;
        for (String val : getFetureVals(dataSet, candidate)) {
            List<List<String>> part = splitDataSet(dataSet, candidate, val);
            double weight = (double) part.size() / dataSet.size();
            entropyAfter += weight * calChannonEnt(part);
        }
        double gain = entropyBefore - entropyAfter;
        if (gain > topGain) {
            topGain = gain;
            topIndex = candidate;
        }
        candidate++;
    }
    return topIndex;
}
/**
 * Returns the class that occurs most often in the given list (majority vote).
 *
 * @param classifyList class values to count
 * @return the most frequent value, or {@code null} when the list is empty; among equally
 *         frequent values the one met first while iterating the internal {@link HashMap} wins
 */
public String majorityCnt(List<String> classifyList){
    Map<String,Integer> classCount = new HashMap<>();
    for (String classify : classifyList) {
        // The first sighting counts as 1; counting it as 0 made a class that occurs
        // once invisible to the "> max" comparison below.
        classCount.merge(classify, 1, Integer::sum);
    }
    int max = 0;
    String key = null;
    for (Map.Entry<String,Integer> entry : classCount.entrySet()) {
        if (entry.getValue() > max) {
            max = entry.getValue();
            key = entry.getKey();
        }
    }
    return key;
}
/**
 * Recursively builds a decision tree from the dataset.
 *
 * <p>A leaf is produced when every row has the same class, or (by majority vote) when only
 * the class column is left. Otherwise the node is named after the chosen feature and gets
 * one child per distinct value of that feature; each child's {@code label} and
 * {@code value} are set to that feature value.
 *
 * @param dataset rows whose last element is the class; must contain at least one row
 * @param labels  feature names, positioned like the feature columns; not modified
 * @return the root node of the (sub)tree
 */
public TreeNode createTree(List<List<String>> dataset,List<String> labels){
    final int classColumn = dataset.get(0).size() - 1;
    List<String> classes = getFetureLists(dataset, classColumn);
    // Every sample agrees on the class: emit a leaf.
    if (new HashSet<>(classes).size() == 1) {
        return new TreeNode(classes.get(0));
    }
    // Only the class column is left: settle by majority vote.
    if (dataset.get(0).size() == 1) {
        return new TreeNode(majorityCnt(classes));
    }
    int splitIndex = chooseBestFetureToSplit(dataset);
    if (splitIndex == -1) {
        // No feature was selected; fall back to the last label.
        splitIndex = labels.size() - 1;
    }
    TreeNode root = new TreeNode(labels.get(splitIndex));
    // Labels handed to the children: everything except the feature consumed here.
    List<String> remaining = new ArrayList<>(labels);
    remaining.remove(splitIndex);
    for (String val : getFetureVals(dataset, splitIndex)) {
        TreeNode subtree = createTree(splitDataSet(dataset, splitIndex, val), remaining);
        TreeNode child = root.addChild(subtree);
        child.setLabel(val);
        child.setValue(val);
    }
    return root;
}
/**
 * Classifies one sample by walking the tree from the given node.
 *
 * <p>A node without children is a leaf and its name is returned as the class; this also
 * covers a tree that consists of a single leaf. For an inner node the sample value of the
 * node's feature selects the child to descend into.
 *
 * @param tree       node to start from
 * @param featLabels feature names, positioned like the values in {@code testData}
 * @param testData   feature values of the sample
 * @return the predicted class, or {@code null} when no child matches the sample's value
 * @throws NullPointerException if an inner node's feature name is not in {@code featLabels}
 */
public static String classify(TreeNode tree,List<String> featLabels,List<String> testData){
    if (tree.childs == null || tree.childs.isEmpty()) {
        return tree.name;
    }
    Map<String,Integer> featLabelMap = new HashMap<>();
    for (int i = 0; i < featLabels.size(); i++) {
        featLabelMap.put(featLabels.get(i), i);
    }
    // Same exception type as the implicit unboxing failure, but with a usable message.
    int featIndex = Objects.requireNonNull(featLabelMap.get(tree.name),
            "feature '" + tree.name + "' is not present in featLabels");
    String sampleValue = testData.get(featIndex);
    for (TreeNode node : tree.childs) {
        if (node.value != null && node.value.equals(sampleValue)) {
            // Stop at the first child whose branch value matches the sample.
            return classify(node, featLabels, testData);
        }
    }
    // The sample carries a value that no branch of this node covers.
    return null;
}
/**
 * Computes the fraction of test rows whose predicted class equals the class stored in the
 * row's last column.
 *
 * @param tree        tree used for the predictions
 * @param labels      feature names, positioned like the feature columns
 * @param testDataSet rows of feature values followed by the expected class; the first row
 *                    determines the position of the class column
 * @return number of matching rows divided by the number of rows
 */
public double recallRate(TreeNode tree,List<String> labels,List<List<String>> testDataSet){
    final int classColumn = testDataSet.get(0).size() - 1;
    int hits = 0;
    for (List<String> row : testDataSet) {
        String expected = row.get(classColumn);
        String predicted = classify(tree, labels, row.subList(0, classColumn));
        if (expected.equals(predicted)) {
            hits++;
        }
    }
    return (double) hits / testDataSet.size();
}
/**
 * Collects the distinct values found in one column, keeping first-seen order.
 *
 * @param dataSet rows to read
 * @param i       zero-based column index
 * @return the distinct values of column {@code i}
 */
private Set<String> getFetureVals(List<List<String>> dataSet, int i) {
    return new LinkedHashSet<>(getFetureLists(dataSet, i));
}
/**
 * Collects the values of one column, one entry per usable row, in row order.
 *
 * <p>Rows that have no element at index {@code i} are logged as errors and skipped. The
 * check is done explicitly because the exception thrown by an out-of-range {@code get}
 * depends on the list implementation of the row.
 *
 * @param dataSet rows to read
 * @param i       zero-based column index
 * @return the values of column {@code i}, duplicates included
 */
private List<String> getFetureLists(List<List<String>> dataSet, int i) {
    List<String> vals = new ArrayList<>(dataSet.size());
    for (List<String> fetures : dataSet) {
        if (fetures == null || i < 0 || i >= fetures.size()) {
            _LOG.error("异常特征:{}", fetures);
            continue;
        }
        vals.add(fetures.get(i));
    }
    return vals;
}
/**
 * Builds a tree and stores it as JSON, reading the training data from the same path the
 * tree is written to. Kept for existing callers; use
 * {@link #storeDTreeToJson(String, String, List)} to keep the training data directory and
 * the output file apart.
 *
 * @param treePath path used both as training data directory and as output file
 * @param labels   feature names, positioned like the feature columns of the training data
 */
public void storeDTreeToJson(String treePath,List<String> labels){
    storeDTreeToJson(treePath, treePath, labels);
}
/**
 * Builds a tree from the training files under {@code trainDataDir} and stores it as JSON
 * in {@code treePath} (requires the fastjson library).
 *
 * <p>An I/O failure while writing is printed and otherwise ignored.
 *
 * @param trainDataDir directory that is searched recursively for training files
 * @param treePath     file the JSON representation of the tree is written to
 * @param labels       feature names, positioned like the feature columns of the training data
 */
public void storeDTreeToJson(String trainDataDir,String treePath,List<String> labels){
    List<List<String>> trainDataSet = loadDataSet(trainDataDir);
    DecisionTreeID3.TreeNode tree = createTree(trainDataSet,labels);
    File treeFile = new File(treePath);
    try {
        // writeStringToFile creates the file when it does not exist yet.
        FileUtils.writeStringToFile(treeFile,JSONObject.toJSONString(tree),"utf-8");
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Loads a decision tree from a JSON file (requires the fastjson library).
 *
 * <p>NOTE(review): the file content is handed to fastjson as-is; confirm that the tree
 * file always comes from a trusted location.
 *
 * @param treePath path of the UTF-8 encoded JSON file
 * @return the deserialised tree, or {@code null} when the file cannot be read
 */
public static TreeNode loadTreeFromJsonFile(String treePath){
    try {
        File treeFile = new File(treePath);
        String jsonText = FileUtils.readFileToString(treeFile,"utf-8");
        return JSONObject.parseObject(jsonText,DecisionTreeID3.TreeNode.class);
    } catch (IOException e) {
        // One log entry that carries the cause, instead of a bare stack trace on stderr
        // followed by a message without context.
        _LOG.error("加载失败。。。", e);
        return null;
    }
}
/**
 * Returns a small fixed sample dataset for trying the algorithm out: two feature columns
 * followed by the class column.
 *
 * @return five rows of the form {@code [feature, feature, class]}
 */
public List<List<String>> createDataset(){
    String[][] samples = {
        {"1", "1", "yes"},
        {"1", "1", "yes"},
        {"1", "0", "no"},
        {"0", "1", "no"},
        {"0", "1", "no"},
    };
    List<List<String>> dataset = new LinkedList<>();
    for (String[] sample : samples) {
        dataset.add(Arrays.asList(sample));
    }
    return dataset;
}
/**
 * Returns the feature names that go with the rows of {@code createDataset()}.
 *
 * @return a new, modifiable list holding the two feature names
 */
public List<String> createLabels(){
    return new ArrayList<>(Arrays.asList("surfacing", "flippers"));
}
/**
 * Prints the tree to standard output, one node per line as
 * {@code <prefix> <value>:<name>}. The prefix is doubled for every level further down.
 *
 * @param tree node to print together with its descendants
 * @param root prefix printed in front of this node
 */
public void print(TreeNode tree,String root){
    System.out.println(root+" "+tree.value+":"+tree.name);
    List<TreeNode> children = tree.childs;
    if (children != null) {
        String childPrefix = root + root;
        for (TreeNode child : children) {
            print(child, childPrefix);
        }
    }
}
/**
 * Node of the decision tree.
 *
 * <p>{@code createTree} names an inner node after the feature it splits on and a leaf
 * after the class it stands for. On every child it sets both {@code label} and
 * {@code value} to the feature value of the branch that leads to that child.
 */
public static class TreeNode implements Serializable {
// Feature name (inner node) or class name (leaf).
private String name;
// Feature value of the branch leading to this node; set by createTree, not read in this file.
private String label;
/* private List<String> labels = new ArrayList<>();*/
// Sub-nodes; empty for a leaf.
private List<TreeNode> childs = new ArrayList<>();
// Feature value of the branch leading to this node; classify compares it with the sample.
private String value;
/** Creates a node without a name (presumably needed for JSON deserialisation; confirm). */
public TreeNode(){
}
/**
 * Creates a node with the given name.
 * @param name feature name or class name
 */
public TreeNode(String name) {
this.name = name;
}
/* public void addLabel(String label){
labels.add(label);
}*/
/**
 * Appends a child node.
 * @param node child to append
 * @return the appended child, i.e. {@code node} itself
 */
public TreeNode addChild(TreeNode node){
childs.add(node);
return node;
}
/**
 * Creates a node with the given name and appends it as a child.
 * @param name name of the new child
 * @return the newly created child
 */
public TreeNode addChild(String name){
return addChild(new TreeNode(name));
}
/** @return the branch value stored on this node */
public String getValue() {
return value;
}
/** @param value the branch value to store on this node */
public void setValue(String value) {
this.value = value;
}
/** @return the node name */
@Override
public String toString() {
return name;
}
/** @return the node name */
public String getName() {
return name;
}
/** @param name the new node name */
public void setName(String name) {
this.name = name;
}
/** @return the live list of children (not a copy) */
public List<TreeNode> getChilds() {
return childs;
}
/** @param childs list that replaces the current children */
public void setChilds(List<TreeNode> childs) {
this.childs = childs;
}
/** @return the label stored on this node */
public String getLabel() {
return label;
}
/** @param label the label to store on this node */
public void setLabel(String label) {
this.label = label;
}
}
/**
 * Loads a dataset from every file found (recursively) under a directory. Each non-blank
 * line becomes one row, split at commas.
 *
 * <p>A file that cannot be read is reported and skipped; the remaining files are still
 * loaded.
 *
 * @param dirPath directory that is searched for data files
 * @return the rows of all readable files, in the order the files were found
 */
public List<List<String>> loadDataSet(String dirPath){
    List<File> dsFiles = new ArrayList<>();
    travelDir(new File(dirPath), dsFiles);
    List<List<String>> dataset = new ArrayList<>();
    for (File file : dsFiles) {
        try {
            for (String line : FileUtils.readLines(file, "utf-8")) {
                // A blank line would turn into the one-element row [""], whose last
                // element would then be treated as a class.
                if (line.trim().isEmpty()) {
                    continue;
                }
                dataset.add(Arrays.asList(line.split(",")));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    return dataset;
}
/**
 * Recursively collects all regular files below a directory.
 *
 * @param dir     directory to walk; when it is not a directory or cannot be listed,
 *                nothing is added
 * @param dsFiles list the files found are appended to
 * @return {@code dsFiles}, for convenience
 */
public List<File> travelDir(File dir,List<File> dsFiles){
    File[] files = dir.listFiles();
    if (files == null) {
        // listFiles() returns null when dir is not a directory or an I/O error occurs.
        return dsFiles;
    }
    for (File file : files) {
        if (file.isDirectory()) {
            travelDir(file, dsFiles);
        } else {
            dsFiles.add(file);
        }
    }
    return dsFiles;
}
/**
 * Small demo: builds a tree from the built-in sample data, prints it, classifies three
 * samples and finally prints the dataset loaded from a directory.
 *
 * @param args optional; {@code args[0]} overrides the built-in training data directory,
 *             which only exists on the original author's machine
 */
public static void main(String[] args) {
    DecisionTreeID3 id3 = new DecisionTreeID3();
    List<List<String>> dataset = id3.createDataset();
    List<String> labels = id3.createLabels();
    _LOG.trace("香农熵:{}", id3.calChannonEnt(dataset));
    _LOG.trace("切分:{}", id3.splitDataSet(dataset,0,"1"));
    _LOG.trace("选择最好的特征分类:{}", id3.chooseBestFetureToSplit(dataset));
    TreeNode tree = id3.createTree(dataset, labels);
    id3.print(tree,"->");
    List<List<String>> samples = Arrays.asList(
            Arrays.asList("1","0"),
            Arrays.asList("1","1"),
            Arrays.asList("0","1"));
    for (List<String> sample : samples) {
        System.out.println(sample+":"+DecisionTreeID3.classify(tree, labels, sample));
    }
    // The default is kept for backward compatibility; pass a directory to run elsewhere.
    String dir = (args != null && args.length > 0)
            ? args[0]
            : "D:\\workspace\\zl\\pre\\mist-parent\\mist-kbqa\\datas\\metion\\train\\train-10";
    System.out.println(id3.loadDataSet(dir));
}
}