分类的概念
分类的基本任务就是根据给定的一系列属性集,最后去判别它属于的类型!
比如我们现在需要去给动物分类,类别可选项为哺乳类,爬行类,鸟类,鱼类,或者两栖类。给你一些属性集如这个动物的体温,是否胎生,是否为水生动物,是否为飞行动物,是否有腿,是否冬眠。
现在分类的基本任务就是,已知一个动物的属性集,判断或预测这个动物属于哪一种类别?
决策树分类法
简述
从根节点开始,每个分支都会包含一个属性测试条件,用于分开具有不同特性的记录,最终到达叶节点,即可得到类标号。
具体过程
从根节点开始,从众多的属性集里边选择一个属性,由这个属性把数据进行分类(该属性的一个值则形成一个孩子节点),得到这个根节点的多个孩子节点。
再由这些孩子节点开始选择剩余的属性来进行分类,递归的进行下去,直至所有属性都已经使用完毕!
问题
(1). 如何确定选择哪个属性来作为测试条件?
某个分类的熵值定义为:
![]()
所以对于一个属性来说,分类后的熵值越低说明数据的纯度越高,这个正是我们想要得到的结果,故使用这个指标来判断属性的优先选择权。
(2). 如何终止递归?避免过度拟合?
数据中可能会出现一些离群点,这会造成决策树在进行决策的过程中对这样的数据非常敏感,所以我们可以使用一个阈值来终止递归(即当前的节点下数据标号的纯度已经满足某个阈值)。
关键代码
private void buildDecisionTree(AttrNode node, String parentAttrValue, String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
node.setParentAttrValue(parentAttrValue);
String attrName = "";
double gainValue = 0;
double tempValue = 0;
if(remainAttr.size() == 1) {
System.out.println("attr null");
return ;
}
// 在所有剩余属性集里选择一个信息增益最大的属性
for(int i = 0;i < remainAttr.size();i ++) {
if(isID3) {
// ID3算法计算信息增益
tempValue = computeGain(remainData, remainAttr.get(i));
} else {
// C4.5算法计算信息增益比
tempValue = computeGainRatio(remainData, remainAttr.get(i));
}
if(tempValue > gainValue) {
gainValue = tempValue;
// 找到最佳的属性
attrName = remainAttr.get(i);
}
}
node.setAttrName(attrName);
// 得到这个属性下的所有取值 去进一步拓展孩子节点
ArrayList<String> valueTypes = attrValue.get(attrName);
// 移除掉这个已经使用了的属性
remainAttr.remove(attrName);
AttrNode[] childNode = new AttrNode[valueTypes.size()];
String[][] rData;
// 遍历这个属性的所有取值
for(int i = 0;i < valueTypes.size();i ++) {
// 把该种取值下的数据提取出来
rData = removeData(remainData, attrName, valueTypes.get(i));
childNode[i] = new AttrNode();
boolean sameClass = true;
ArrayList<String> indexArray = new ArrayList<>();
// 遍历剩余的数据
for(int k = 1;k < rData.length;k ++) {
indexArray.add(rData[k][0]);
if (!rData[k][attrNames.length - 1].equals(rData[1][attrNames.length - 1])) {
sameClass = false;
break;
}
}
if(!sameClass) {
buildDecisionTree(childNode[i], valueTypes.get(i), rData, remainAttr, isID3);
} else {
// 如果数据中标号全部相同(或者是达到了某个阈值)停止递归
childNode[i].setParentAttrValue(valueTypes.get(i));
childNode[i].setChildDataIndex(indexArray);
}
}
// 递归完成后,给头结点设定孩子节点
node.setChildAttrNode(childNode);
}
总结
决策树分类算法是属于监督学习的算法,也就是他需要初始的数据来进行训练,去得到一个经过训练的模型。然后这个模型就可以用来根据属性集预测标号。它的不足在于它无法进行增量计算,也就是当新增一些已知的数据集的时候,只有重新结合之前的数据来重新构建决策树,而无法仅仅利用增量来构建强化。但是这类算法的思路非常简单,理解起来也不难。
引申
CART算法(Classification And Regression Tree):也是一种决策树分类算法,与之前的C4.5和ID3不同的是:
1. 每个非叶子节点都有两个孩子节点,这也就意味着划分条件仅为等于和不等于某个值,来对数据进行划分空间。
2. CART算法对于属性的值采用的是基于Gini系数值的方式做比较,举一个网上的一个例子:(划分条件为体温是否恒温)
比如体温为恒温时包含哺乳类5个、鸟类2个,则:
Gini(left_child)=1−(