说明
此前已经上传了ID3决策树的Java实现,C4.5整体架构与之相差不大。
可参考:http://blog.csdn.net/xiaohukun/article/details/78041676
此次将结点的实现由Dom4J改为自定义类实现,更加自由和轻便。
代码已打包并上传
代码
数据仍采用ARFF格式
train.arff
@relation weather.symbolic
@attribute outlook {sunny,overcast,rainy}
@attribute temperature {hot,mild,cool}
@attribute humidity {high,normal}
@attribute windy {TRUE,FALSE}
@attribute play {yes,no}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
C4.5主类:DecisionTree
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.io.FileOutputStream;
import java.io.BufferedOutputStream;
import java.lang.Math.*;
public class DecisionTree {
private ArrayList<String> train_AttributeName = new ArrayList<String>(); // names of the training-set attributes
private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // the declared values of each training-set attribute
private ArrayList<String[]> trainData = new ArrayList<String[]>(); // training rows, i.e. the @data section of the ARFF file
// Regex for an ARFF nominal declaration "@attribute <name> {v1,v2,...}".
// Group 1 captures the text between "@attribute" and "{" (the attribute name, surrounding whitespace included);
// group 2 captures the comma-separated value list. The reluctant "*?" in group 2 repeats as few times as
// possible, so the match stops at the FIRST "}" instead of a later one.
// NOTE(review): unless the caller compiles this with Pattern.DOTALL, "." does not match line breaks,
// so a declaration split across lines (e.g. "{" and "}" on different lines) will not match.
public static final String patternString = "@attribute(.*)[{](.*?)[}]";
private int decatt; // index of the decision (class-label) attribute, i.e. the column holding the class label
private InfoGain infoGain; // information-gain helper; created in train() from trainData and decatt
private TreeNode root; // root of the decision tree built and then pruned by train()
/**
 * Trains the C4.5 decision tree from an ARFF file.
 *
 * <p>Steps, in order: load the ARFF training data, select the class attribute
 * (presumably setting {@code decatt}), create the {@code InfoGain} helper, build the tree
 * over all training rows and all non-class columns, and finally prune the tree.
 *
 * @param data_path  path of the ARFF training file
 * @param targetAttr name of the class (decision) attribute
 */
public void train(String data_path, String targetAttr) {
    // Load the data and fix which column is the class label.
    read_trainARFF(new File(data_path));
    setDec(targetAttr);
    infoGain = new InfoGain(trainData, decatt);

    // Candidate split columns: every attribute except the class column.
    // The class column is skipped by index, so it does not have to be the last one.
    LinkedList<Integer> candidateColumns = new LinkedList<>();
    int attributeCount = train_AttributeName.size();
    for (int col = 0; col < attributeCount; col++) {
        if (col == decatt) {
            continue;
        }
        candidateColumns.add(col);
    }

    // Row subset for the root: every training row.
    int rowCount = trainData.size();
    ArrayList<Integer> allRows = new ArrayList<>(rowCount);
    for (int row = 0; row < rowCount; row++) {
        allRows.add(row);
    }

    // Grow the full tree, then prune it.
    root = buildDT("root", "null", allRows, candidateColumns);
    cutBranch(root);
}
/**
* 构建决策树
* @param fatherName 节点名称
* @param fatherValue 节点值
* @param subset 数据行子集
* @param selatt 数据列子集(候选划分属性的列索引)
* @return 返回根节点
*/
public TreeNode buildDT(String fatherName, String fatherValue, ArrayList<Integer> subset,LinkedList<Integer> selatt){
TreeNode node=new TreeNode();
Map<String,Integer> targetNum = infoGain.get_AttributeNum(subset,decatt);//计算类-频率
String targetValue=infoGain.get_targetValue(targetN