接着上文,再介绍决策树的其他几种算法:SLIQ、SPRINT和CART。这些算法都以Gini指标作为选择划分的准则。
SLIQ
SLIQ(Supervised Learning In Quest)利用三种数据结构来构造树,分别是属性表、类表和类直方图。
SLIQ算法在建树阶段,对连续属性采取预先排序技术与广度优先相结合的策略生成树,对离散属性采取快速求子集算法确定划分条件。
具体步骤如下:
step1:建立类表和各个属性表,并且进行预先排序,即对每个连续属性的属性表进行独立的排序,以避免在每个节点上都要给连续属性值重新排序;
step2:如果每个叶子节点中的样本都能归为一类,则算法停止;否则转step3;
step3:利用属性表计算gini值,选择最小gini值的属性和分割点作为最佳划分;
step4:根据step3得到的最佳划分节点,判断为真的样本划分为左孩子节点,否则划分为右孩子节点。由于每一轮都对当前所有叶子节点同时进行划分,这样就构成了广度优先的生成树策略;
step5:更新类表中的第二项,使之指向样本划分后所在的叶子节点;
step6:跳转到step2
分类回归树算法:CART(Classification And Regression Tree)算法采用一种二分递归分割的技术,将当前的样本集分为两个子样本集,使得生成的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。
分类树有两个基本思想:第一个是对训练样本递归地划分自变量空间来建树,第二个是用验证数据对树进行剪枝。
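这里只用Java对SPRINT算法做了一个简单实现。SPRINT是对SLIQ的改进:它不再使用需要常驻内存的类表,而是把类别标签直接存放在每个属性表的记录中(对应下面代码中的Attribute:样本id、属性名、属性值和类别)。此外,该实现把所有属性都当作离散属性处理:对每个取值尝试"等于该值/不等于该值"的二分,并选择加权Gini指数最小的划分。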
/**
 * Recursively builds a binary decision tree (SPRINT-style) for {@code data}.
 *
 * <p>An attribute table is built for every attribute. Among the attributes still available at
 * this node, the (attribute, value) pair with the smallest weighted Gini index is chosen as the
 * split. Instances whose value equals that split point form the first child. All remaining
 * instances form the second child, whose branch name is the comma-joined list of the other
 * values seen at this node. Each child is built recursively without the split attribute.
 *
 * @param data the instances and candidate attributes for this node; its attribute set is not
 *     modified
 * @return the non-null result of {@code preHandle(data)} when it short-circuits, otherwise a
 *     {@code TreeNode} for the chosen split
 * @throws IllegalArgumentException if none of the candidate attributes occurs in the instances,
 *     so no split can be chosen
 */
@Override
public Object build(Data data) {
    // preHandle short-circuits with a leaf value: per the original note, the most frequent
    // category when there are no attributes, or the category itself when all instances share
    // one. A null result means this node must be split.
    Object preHandleResult = preHandle(data);
    if (null != preHandleResult) {
        return preHandleResult;
    }

    // Attribute tables: for each attribute name, one (id, name, value, category) record per
    // instance. Values are stored as String.valueOf(value), so the split point chosen below is
    // a string and instances must later be compared by their string form too.
    Map<String, List<Attribute>> attributeTableMap =
        new HashMap<String, List<Attribute>>();
    for (Instance instance : data.getInstances()) {
        String category = String.valueOf(instance.getCategory());
        Map<String, Object> attrs = instance.getAttributes();
        for (Map.Entry<String, Object> entry : attrs.entrySet()) {
            String attrName = entry.getKey();
            List<Attribute> attributeTable = attributeTableMap.get(attrName);
            if (null == attributeTable) {
                attributeTable = new ArrayList<Attribute>();
                attributeTableMap.put(attrName, attributeTable);
            }
            attributeTable.add(new Attribute(instance.getId(),
                attrName, String.valueOf(entry.getValue()), category));
        }
    }

    // Pick the (attribute, split point) with the smallest weighted Gini index among the
    // attributes still available at this node.
    Set<String> attributes = data.getAttributeSet();
    String splitAttribute = null;
    String minSplitPoint = null;
    double minSplitPointGini = 1.0; // a weighted Gini index is always < 1
    for (Map.Entry<String, List<Attribute>> entry : attributeTableMap.entrySet()) {
        String attribute = entry.getKey();
        if (!attributes.contains(attribute)) {
            continue;
        }
        Object[] result = calculateMinGini(entry.getValue());
        double splitPointGini = ((Number) result[1]).doubleValue();
        if (minSplitPointGini > splitPointGini) {
            minSplitPointGini = splitPointGini;
            minSplitPoint = String.valueOf(result[0]);
            splitAttribute = attribute;
        }
    }
    System.out.println("splitAttribute: " + splitAttribute);
    if (null == splitAttribute) {
        // Without a split attribute, the partition below would consume no attribute and the
        // recursion would never terminate.
        throw new IllegalArgumentException("no attribute of " + attributes
            + " occurs in the instances; cannot choose a split");
    }
    TreeNode treeNode = new TreeNode(splitAttribute);

    // Attributes for the children: the current ones minus the split attribute, in the same
    // iteration order. Collected into a new list so the caller's attribute set is not mutated.
    List<String> remainingAttributes = new ArrayList<String>();
    for (String attribute : attributes) {
        if (!splitAttribute.equals(attribute)) {
            remainingAttributes.add(attribute);
        }
    }

    // Binary partition: instances whose value equals the split point vs. all the others.
    Set<String> attributeValues = new HashSet<String>();
    List<Instance> splitInstances1 = new ArrayList<Instance>();
    List<Instance> splitInstances2 = new ArrayList<Instance>();
    for (Instance instance : data.getInstances()) {
        // Compare string forms, exactly as the values were stored in the attribute tables.
        // This also avoids an NPE when an instance has a null value.
        String value = String.valueOf(instance.getAttribute(splitAttribute));
        attributeValues.add(value);
        if (value.equals(minSplitPoint)) {
            splitInstances1.add(instance);
        } else {
            splitInstances2.add(instance);
        }
    }

    // The second branch is named by the comma-joined values that fell on the "other" side.
    attributeValues.remove(minSplitPoint);
    StringBuilder sb = new StringBuilder();
    for (String attributeValue : attributeValues) {
        sb.append(attributeValue).append(",");
    }
    if (sb.length() > 0) {
        sb.deleteCharAt(sb.length() - 1);
    }
    String[] names = new String[]{minSplitPoint, sb.toString()};

    List<List<Instance>> splitInstancess = new ArrayList<List<Instance>>();
    splitInstancess.add(splitInstances1);
    splitInstancess.add(splitInstances2);
    for (int i = 0; i < 2; i++) {
        List<Instance> splitInstances = splitInstancess.get(i);
        if (splitInstances.isEmpty()) {
            continue; // no instance on this side: no child branch
        }
        Data subData = new Data(remainingAttributes.toArray(new String[0]),
            splitInstances);
        treeNode.setChild(names[i], build(subData));
    }
    return treeNode;
}
/**
 * Finds the best one-vs-rest binary split of a single attribute table by Gini index.
 *
 * <p>Every distinct attribute value {@code v} is tried as a split point. Records with value
 * {@code v} form one side and all other records form the other side. The score is the weighted
 * Gini index {@code (n1/n) * gini(side1) + (n2/n) * gini(side2)}, where
 * {@code gini(s) = 1 - sum_k p_k^2} over the category proportions {@code p_k} on side {@code s}.
 *
 * <p>The category histogram of the "other" side is derived as (overall histogram minus the
 * histogram of {@code v}) rather than re-accumulated from every other value. This keeps the
 * method linear, not quadratic, in the number of distinct values.
 *
 * @param attributeTable the attribute-table records of one attribute
 * @return a two-element array {@code {splitPoint, gini}}: the value with the smallest weighted
 *     Gini index (a {@code String}) and that index (a {@code Double}); {@code {null, 1.0}} when
 *     the table is empty
 */
public Object[] calculateMinGini(List<Attribute> attributeTable) {
    // One pass: category histogram per attribute value, plus the overall category histogram.
    Map<String, Map<String, Integer>> attrValueSplits =
        new HashMap<String, Map<String, Integer>>();
    Map<String, Integer> totalCategoryCounts = new HashMap<String, Integer>();
    for (Attribute attribute : attributeTable) {
        String attributeValue = attribute.getValue();
        String category = attribute.getCategory();
        Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue);
        if (null == attrValueSplit) {
            attrValueSplit = new HashMap<String, Integer>();
            attrValueSplits.put(attributeValue, attrValueSplit);
        }
        Integer categoryNum = attrValueSplit.get(category);
        attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1);
        Integer categoryTotal = totalCategoryCounts.get(category);
        totalCategoryCounts.put(category, null == categoryTotal ? 1 : categoryTotal + 1);
    }
    double totalNum = attributeTable.size();

    String minSplitPoint = null;
    double minSplitPointGini = 1.0; // a weighted Gini index is always < 1
    // Candidates are visited in hash order; on ties the first one visited wins (strict '>').
    for (Map.Entry<String, Map<String, Integer>> entry : attrValueSplits.entrySet()) {
        Map<String, Integer> aboveCounts = entry.getValue();

        double splitAboveNum = 0.0;
        for (Integer v : aboveCounts.values()) {
            splitAboveNum += v;
        }
        double splitBelowNum = totalNum - splitAboveNum;

        double aboveGini = 1.0;
        for (Integer v : aboveCounts.values()) {
            aboveGini -= Math.pow(v / splitAboveNum, 2);
        }

        double belowGini = 1.0;
        for (Map.Entry<String, Integer> e : totalCategoryCounts.entrySet()) {
            Integer above = aboveCounts.get(e.getKey());
            int below = e.getValue() - (null == above ? 0 : above);
            // Categories absent from the "other" side contribute nothing. Skipping them also
            // avoids 0/0 when that side is empty (a single distinct value).
            if (below > 0) {
                belowGini -= Math.pow(below / splitBelowNum, 2);
            }
        }

        double splitPointGini = (splitAboveNum / totalNum) * aboveGini
            + (splitBelowNum / totalNum) * belowGini;
        if (minSplitPointGini > splitPointGini) {
            minSplitPointGini = splitPointGini;
            minSplitPoint = entry.getKey();
        }
    }
    return new Object[]{minSplitPoint, minSplitPointGini};
}