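Distribution.java holds the class distribution of a set of instances; split model selection uses it to check whether the instances all belong to the same class.
Variables: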
/** Weight of instances per class per bag. */
private double m_perClassPerBag[][];
/** Weight of instances per bag. */
private double m_perBag[];
/** Weight of instances per class. */
private double m_perClass[];
/** Total weight of instances. */
private double totaL;
Constructor (creates a single bag):
public Distribution(Instances source) throws Exception {
m_perClassPerBag = new double [1][0];
m_perBag = new double [1];
totaL = 0;
m_perClass = new double [source.numClasses()];
m_perClassPerBag[0] = new double [source.numClasses()];
Enumeration enu = source.enumerateInstances();// for details on the Enumeration interface, see http://blog.youkuaiyun.com/zhiweianran/article/details/7672433
while (enu.hasMoreElements())
add(0,(Instance) enu.nextElement());
}
The Distribution.add() method
Adds the instance's weight to the count of its class in the given bag
public final void add(int bagIndex,Instance instance)
throws Exception {
int classIndex;
double weight;
classIndex = (int)instance.classValue();// cast the instance's class value to an int
weight = instance.weight();
m_perClassPerBag[bagIndex][classIndex] =
m_perClassPerBag[bagIndex][classIndex]+weight;
m_perBag[bagIndex] = m_perBag[bagIndex]+weight;
m_perClass[classIndex] = m_perClass[classIndex]+weight;
totaL = totaL+weight;
}
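Summary: Distribution.java assigns each instance to its class and keeps track of how much weight falls into each class. Each instance's weight defaults to 1, so the per-class counts are simply instance counts and totaL is the total number of instances. Bags add one more level of grouping on top of the classes.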
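To make the bookkeeping concrete, here is a minimal stand-alone sketch of the same counting scheme. It uses plain indices instead of Weka's Instances and Instance types; the class name BagCounts and its method signatures are hypothetical and chosen only for illustration.

```java
// Hypothetical stand-in for Distribution: the same four counters, no Weka types.
class BagCounts {
    final double[][] perClassPerBag; // weight of instances per class per bag
    final double[] perBag;           // weight of instances per bag
    final double[] perClass;         // weight of instances per class
    double total;                    // total weight of instances

    BagCounts(int numBags, int numClasses) {
        perClassPerBag = new double[numBags][numClasses];
        perBag = new double[numBags];
        perClass = new double[numClasses];
    }

    // Mirrors Distribution.add(): one weight is added to all four counters.
    void add(int bagIndex, int classIndex, double weight) {
        perClassPerBag[bagIndex][classIndex] += weight;
        perBag[bagIndex] += weight;
        perClass[classIndex] += weight;
        total += weight;
    }

    public static void main(String[] args) {
        BagCounts counts = new BagCounts(2, 2);
        counts.add(0, 0, 1.0); // bag 0, class 0
        counts.add(0, 1, 1.0); // bag 0, class 1
        counts.add(1, 1, 1.0); // bag 1, class 1
        System.out.println("bag 0: " + counts.perBag[0]
                + ", class 1: " + counts.perClass[1]
                + ", total: " + counts.total);
    }
}
```

With every weight equal to 1, each counter is just a number of instances, which matches the summary above.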
C45Split.java is used to split on an attribute
Variables:
/** Desired number of branches. */
private int m_complexityIndex;
/** Attribute to split on. */
private int m_attIndex;
/** Minimum number of objects in a split. */
private int m_minNoObj;
/** Value of split point. */
private double m_splitPoint;
/** InfoGain of split. */
private double m_infoGain;
/** GainRatio of split. */
private double m_gainRatio;
/** The sum of the weights of the instances. */
private double m_sumOfWeights;
/** Number of split points. */
private int m_index;
/** Static reference to splitting criterion. */
private static InfoGainSplitCrit infoGainCrit = new InfoGainSplitCrit();
/** Static reference to splitting criterion. */
private static GainRatioSplitCrit gainRatioCrit = new GainRatioSplitCrit();
Constructor:
public C45Split(int attIndex,int minNoObj, double sumOfWeights) {
// Get index of attribute to split on.
m_attIndex = attIndex;// index of the attribute to split on (not the class index)
// Set minimum number of objects.
m_minNoObj = minNoObj;
// Set the sum of the weights
m_sumOfWeights = sumOfWeights;
}
Next is buildClassifier(), which creates a C4.5-type split
public void buildClassifier(Instances trainInstances)
throws Exception {
// Initialize the remaining instance variables.
m_numSubsets = 0;
m_splitPoint = Double.MAX_VALUE;
m_infoGain = 0;
m_gainRatio = 0;
// Different treatment for enumerated and numeric
// attributes.
if (trainInstances.attribute(m_attIndex).isNominal()) {// nominal attribute
m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();// number of branches
m_index = m_complexityIndex;// number of split points
handleEnumeratedAttribute(trainInstances);// split on the nominal attribute
}else{// numeric attribute
m_complexityIndex = 2;
m_index = 0;
trainInstances.sort(trainInstances.attribute(m_attIndex));// sort the instances by this attribute
handleNumericAttribute(trainInstances);// split on the numeric attribute
}
}
Now let's look at handleEnumeratedAttribute(), which splits on a nominal attribute:
private void handleEnumeratedAttribute(Instances trainInstances)
throws Exception {
Instance instance;
m_distribution = new Distribution(m_complexityIndex,
trainInstances.numClasses());
// Constructor public Distribution(int numBags,int numClasses): each branch is one bag
// Only Instances with known values are relevant.
Enumeration enu = trainInstances.enumerateInstances();
while (enu.hasMoreElements()) {
instance = (Instance) enu.nextElement();
if (!instance.isMissing(m_attIndex))
m_distribution.add((int)instance.value(m_attIndex),instance);
}
// Check if minimum number of Instances in at least two
// subsets.
if (m_distribution.check(m_minNoObj)) {
m_numSubsets = m_complexityIndex;
m_infoGain = infoGainCrit.// information gain
splitCritValue(m_distribution,m_sumOfWeights);
m_gainRatio =
gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,
m_infoGain);
}
}
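The two criteria computed above can be illustrated with a small worked example. The sketch below is a simplified textbook computation of information gain and gain ratio from per-bag, per-class counts; it is not a copy of Weka's InfoGainSplitCrit or GainRatioSplitCrit (for instance, it ignores instances with missing values), and the class name GainSketch and all of its methods are hypothetical. The check() helper only mirrors the condition stated in the source comment: at least two subsets must contain the minimum number of instances.

```java
// Simplified, hypothetical sketch; not Weka's actual split criteria classes.
class GainSketch {
    static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    // Entropy in bits of a vector of non-negative counts.
    static double entropy(double[] counts) {
        double sum = 0;
        for (double c : counts) sum += c;
        if (sum == 0) return 0;
        double h = 0;
        for (double c : counts) {
            if (c > 0) h -= (c / sum) * log2(c / sum);
        }
        return h;
    }

    // Sums the weight in each bag; bags[b][c] is the weight of class c in bag b.
    static double[] perBag(double[][] bags) {
        double[] sums = new double[bags.length];
        for (int b = 0; b < bags.length; b++)
            for (double c : bags[b]) sums[b] += c;
        return sums;
    }

    // Class entropy before the split minus the weighted entropy after it.
    static double infoGain(double[][] bags) {
        double[] perClass = new double[bags[0].length];
        double[] bagSums = perBag(bags);
        double total = 0;
        for (int b = 0; b < bags.length; b++) {
            for (int c = 0; c < bags[b].length; c++) perClass[c] += bags[b][c];
            total += bagSums[b];
        }
        if (total == 0) return 0;
        double remainder = 0;
        for (int b = 0; b < bags.length; b++)
            remainder += (bagSums[b] / total) * entropy(bags[b]);
        return entropy(perClass) - remainder;
    }

    // Gain ratio = information gain / entropy of the bag sizes (split information).
    static double gainRatio(double[][] bags) {
        double splitInfo = entropy(perBag(bags));
        return splitInfo == 0 ? 0 : infoGain(bags) / splitInfo;
    }

    // True if at least two bags hold minNoObj or more weight.
    static boolean check(double[][] bags, double minNoObj) {
        int counter = 0;
        for (double s : perBag(bags)) if (s >= minNoObj) counter++;
        return counter > 1;
    }

    public static void main(String[] args) {
        // Classic 14-instance weather data, attribute "outlook":
        // sunny {yes=2, no=3}, overcast {yes=4, no=0}, rainy {yes=3, no=2}
        double[][] outlook = { {2, 3}, {4, 0}, {3, 2} };
        if (check(outlook, 2)) {
            System.out.printf("infoGain=%.4f gainRatio=%.4f%n",
                    infoGain(outlook), gainRatio(outlook));
        }
    }
}
```

For the outlook split, the information gain comes out at roughly 0.247 bits and the gain ratio at roughly 0.156.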
Returns the information gain:
public final double infoGain() {
return m_infoGain;
}
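Sets the split point to the largest value in the given dataset that is less than or equal to the old split point (numeric attributes only):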
setSplitPoint()
public final void setSplitPoint(Instances allInstances) {
double newSplitPoint = -Double.MAX_VALUE;
double tempValue;
Instance instance;
if ((allInstances.attribute(m_attIndex).isNumeric()) &&
(m_numSubsets > 1)) {
Enumeration enu = allInstances.enumerateInstances();
while (enu.hasMoreElements()) {
instance = (Instance) enu.nextElement();
if (!instance.isMissing(m_attIndex)) {
tempValue = instance.value(m_attIndex);
if (Utils.gr(tempValue,newSplitPoint) &&
Utils.smOrEq(tempValue,m_splitPoint))
newSplitPoint = tempValue;
}
}
m_splitPoint = newSplitPoint;
}
}
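The selection rule in setSplitPoint() can be tried out on a plain array. The following is a minimal sketch under stated assumptions: the class SplitPointSketch and the method adjustSplitPoint are hypothetical, NaN stands in for a missing value, and ordinary comparisons replace Utils.gr and Utils.smOrEq.

```java
// Hypothetical sketch of the selection rule used in setSplitPoint().
class SplitPointSketch {
    // Returns the largest value that is less than or equal to oldSplitPoint.
    // If no value qualifies, -Double.MAX_VALUE is returned, as in the source.
    static double adjustSplitPoint(double[] values, double oldSplitPoint) {
        double newSplitPoint = -Double.MAX_VALUE;
        for (double v : values) {
            if (Double.isNaN(v)) continue; // assumption: NaN marks a missing value
            if (v > newSplitPoint && v <= oldSplitPoint) {
                newSplitPoint = v;
            }
        }
        return newSplitPoint;
    }

    public static void main(String[] args) {
        double[] values = {1.0, 2.5, Double.NaN, 3.0, 4.2};
        // 3.6 lies between the observed values 3.0 and 4.2
        System.out.println(adjustSplitPoint(values, 3.6));
    }
}
```

After the adjustment, the split point is a value that actually occurs in the data rather than a point between two observed values.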
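Summary: C45Split.java evaluates one candidate attribute as a split by computing its information gain and gain ratio, so that the attribute that improves the gain ratio can be chosen as the split, which is the so-called model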