C4.5的学习
本文写得比较简略，更详细的介绍可以参考下面的博客：
http://www.cnblogs.com/biyeymyhjob/archive/2012/07/23/2605208.html
http://computerght.blog.163.com/blog/static/522736862013102992416752/
C4.5算法相关资料
C4.5算法是在ID3算法的基础上改进而来的：ID3算法利用信息增益构建决策树，而C4.5利用信息增益率来构建决策树，并且加入了剪枝来减轻过拟合。
以下公式用于计算所需信息。
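- $\mathrm{Info}(D) = -\sum_{i=1}^{m} p_i \log_2 p_i$；对类标签计算信息熵，其中 $p_i$ 为 $D$ 中属于第 $i$ 类的样本所占的比例
- $\mathrm{Info}_A(D) = \sum_{j=1}^{v} \frac{|D_j|}{|D|}\,\mathrm{Info}(D_j)$；按属性 $A$ 的 $v$ 个取值把 $D$ 划分为 $D_1,\dots,D_v$ 后的期望信息（条件熵）
- $\mathrm{Gain}(A) = \mathrm{Info}(D) - \mathrm{Info}_A(D)$；对属性 $A$ 计算信息增益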
ID3算法利用信息增益构建决策树：每次选择信息增益最大的属性进行分裂。
C4.5算法考虑到信息增益偏向于取值较多的属性，因此改用信息增益率来选择分裂属性。
- $\mathrm{SplitInfo}_A(D) = -\sum_{j=1}^{v} \frac{|D_j|}{|D|}\log_2\frac{|D_j|}{|D|}$；分裂信息
- $\mathrm{GainRatio}_A(D) = \frac{\mathrm{Gain}(A)}{\mathrm{SplitInfo}_A(D)}$；信息增益率
以上是C4.5算法需要用到的公式。
下面使用UCI数据库中的数据，用C/C++实现（代码全部为原创，可能有点乱，后面根据情况可能会进一步修改；将得到的决策树用于预测测试样本，可以得到70%左右的正确率）。
结果小结：
1. 程序中通过限制树的深度（maxDeep）来防止树过于复杂（也可以采用其他剪枝思路）。
2. 训练数据集采用了前1650个样本（共1728个），其余样本用作测试；更好的做法应该是随机选取训练数据。
代码只有两个文件，没有将函数拆分到单独的源文件中：所有功能都在main.h中实现，具体的使用示例在main.cpp中。
<main.h>
<pre name="code" class="cpp">/****************************
DATA: Car Evaluation Database(from UCI)
决策树学习算法C实现(彭杰)
Copyright 2015/3/21 owner by pengjie
All rights reserved
*******************************/
#ifndef MAIN_H
#define MAIN_H
#include "stdio.h"
#include "stdlib.h"
#include <map>
#include <iostream>
#include <string>
#include <string.h>
#include <vector>
#include <math.h>
// Capacity of DesTreePoint::nextloca, i.e. the most children one node can have.
#define maxClass 10
// CreateDecisionTree turns every node whose depth exceeds this into a leaf.
#define maxDeep 3
using namespace std;
// Maps the raw text of an attribute value to its encoded digit character ('0', '1', ...).
typedef std::map<std::string, char> StringCharMap;
// Index of the last slot handed out in the tree array; advanced by CreateDecisionTree.
int locp = 0;
// One split condition on the path from the root: "column attrName has value attrValue".
struct attrPoint {
    int attrName;   // column index of the attribute tested by the parent node
    char attrValue; // encoded value of that attribute on this branch ('0' + k)
};
// One node of the decision tree. Nodes live in a flat array and refer to
// each other by index (loca / preloca / nextloca).
struct DesTreePoint {
    attrPoint prepos;       // condition leading from the parent to this node (attrName == -1 at the root)
    int curpos;             // attribute column this node splits on
    int deep;               // depth of the node, root = 0
    int nextPointNum;       // number of children (= number of values of curpos)
    char label;             // '#' while the node is internal, otherwise the predicted class ('0' + k)
    int loca;               // this node's own index in the tree array
    int preloca;            // parent's index in the tree array (-1 at the root)
    int nextloca[maxClass]; // children's indices in the tree array
};
vector<DesTreePoint> treeQueue;
// Dump the bookkeeping fields of one tree node to stdout (debugging aid).
void print(DesTreePoint point)
{
    const attrPoint &edge = point.prepos;
    printf("prepos--name,value:%d,%c\n", edge.attrName, edge.attrValue);
    printf("curpos: %d\n"
           "deep: %d\n"
           "nextPointNum: %d\n"
           "loca: %d\n"
           "preloca:%d\n",
           point.curpos, point.deep, point.nextPointNum, point.loca, point.preloca);
}
int ReadData(char **data, int attriSome[], StringCharMap *mapAttr, char *file,const int dataline,const int attrinum)
{
FILE *pFile;
char buf[256];
pFile = fopen(file,"rt");
if(pFile==NULL)
{
printf("the data file is not existing: %s\n", file);
return -1;
}
int row = 0; //data line
int cloumn = 0; //data attribute
char delim[] = ",";//data delimiter
string tmpdata;//data cache
while(!feof(pFile)&&row<dataline)
{
fgets(buf,256,pFile);
//printf("%s\t %d\n",buf,row);
/*printf("%d-%c\n",strlen(buf),buf[strlen(buf)-1]);*/
/*buf[strlen(buf)-1]=='\n';*/
if( buf[strlen(buf)-1]=='\n' )
{
buf[strlen(buf)-1]='\0';
}
for(cloumn=0;cloumn<attrinum;++cloumn )
{
if( cloumn==0 )
{
tmpdata = strtok(buf,delim);
//tmpdata[strlen(tmpdata)] = '\0';
//printf("%s\t",tmpdata.c_str());
data[row][cloumn] = mapAttr[cloumn][tmpdata];
}
else
{
tmpdata = strtok(NULL,delim);
//tmpdata[strlen(tmpdata)] = '\0';
//printf("%s,%d\t",tmpdata.c_str(),strlen(tmpdata.c_str()));
data[row][cloumn] = mapAttr[cloumn][tmpdata];
}
}
//printf("\n");
++row;
}
return 1;
}
/* Entropy contribution -p*log2(p) of a bucket holding `count` of `total` rows (0 for an empty bucket). */
static double GainRatioEntropyTerm(int count, int total)
{
    if (count <= 0 || total <= 0)
        return 0.0;
    const double p = double(count) / total;
    return -p * log(p) / log(2.0);
}

/*
 * GainRatio - C4.5 gain ratio of splitting on column `attrname`.
 *
 * Only rows among the first `datasize` whose columns match every condition
 * attrp[0..deep-1] (data[row][attrName] == attrValue) are counted. With
 * deep == 0 all rows are used and attrp may be NULL. Fields are encoded as
 * '0' + k, the last column (attrinum - 1) is the class label and attriSome[c]
 * is the number of values of column c.
 *
 * classfiedLabel (out): majority class of the selected rows (the lowest index
 * wins ties; 0 when no row is selected).
 *
 * Returns 0.0 in three cases:
 *   - no row is selected;
 *   - the rows are (almost) pure, i.e. entropy ~ 0 or majority share > 0.95;
 *   - `attrname` takes a single value in them (SplitInfo == 0).
 * Otherwise returns (Info(D) - Info_A(D)) / SplitInfo_A(D).
 */
double GainRatio( char **data,const int datasize, const int attrinum,int deep,attrPoint *attrp,int *attriSome,int attrname,int &classfiedLabel )
{
    const int labelnum = attriSome[attrname];        // values of the candidate attribute
    const int classifynum = attriSome[attrinum - 1]; // number of classes

    // joint[v][c]: selected rows whose attribute value is v and class is c.
    // std::vector releases the counts on return (the raw arrays used to leak).
    std::vector<std::vector<int> > joint(labelnum, std::vector<int>(classifynum, 0));
    std::vector<int> valueCount(labelnum, 0);   // selected rows per attribute value
    std::vector<int> classCount(classifynum, 0); // selected rows per class
    int total = 0;

    for (int i = 0; i < datasize; ++i)
    {
        bool selected = true;
        for (int j = 0; j < deep && selected; ++j)
            selected = (data[i][attrp[j].attrName] == attrp[j].attrValue);
        if (!selected)
            continue;
        const int v = data[i][attrname] - '0';
        const int c = data[i][attrinum - 1] - '0';
        if (v < 0 || v >= labelnum || c < 0 || c >= classifynum)
            continue; // malformed code: skip it rather than index out of bounds
        ++joint[v][c];
        ++valueCount[v];
        ++classCount[c];
        ++total;
    }

    classfiedLabel = 0;
    if (total == 0)
        return 0.0; // empty branch: nothing to split

    // Info(D) and the majority class.
    double infoD = 0.0;
    for (int c = 0; c < classifynum; ++c)
    {
        if (classCount[c] > classCount[classfiedLabel])
            classfiedLabel = c;
        infoD += GainRatioEntropyTerm(classCount[c], total);
    }
    const double maxpi = double(classCount[classfiedLabel]) / total;
    if (fabs(infoD) < 0.0000001 || maxpi > 0.95)
        return 0.0; // (almost) pure node

    // Info_A(D) and SplitInfo_A(D) over the non-empty branches. Empty
    // branches contribute nothing; dividing by their zero count produced NaN.
    // A zero-count class inside a branch is simply skipped (it must not reset
    // the branch entropy accumulated so far).
    double infoDA = 0.0;
    double splitInfo = 0.0;
    for (int v = 0; v < labelnum; ++v)
    {
        if (valueCount[v] == 0)
            continue;
        double infoDj = 0.0;
        for (int c = 0; c < classifynum; ++c)
            infoDj += GainRatioEntropyTerm(joint[v][c], valueCount[v]);
        infoDA += double(valueCount[v]) / total * infoDj;
        splitInfo += GainRatioEntropyTerm(valueCount[v], total);
    }

    // Only one populated branch: the attribute does not split the rows.
    if (splitInfo < 0.0000001)
        return 0.0;
    return (infoD - infoDA) / splitInfo;
}
/*
 * CreateDecisionTree - grow a C4.5-style decision tree into `tree`.
 *
 * tree      : flat node array. Node 0 is the root and every new node takes
 *             the next free slot (global locp). The capacity is NOT checked
 *             here; the caller must size the array for the deepest tree
 *             maxDeep allows. Assumes attriSome[c] <= maxClass for every
 *             attribute column.
 * datasize  : number of leading rows of `data` used for training.
 * attrinum  : columns per row; the last column is the class label.
 * data      : encoded rows ('0' + k per field).
 * attriSome : number of distinct values per column.
 *
 * Pending nodes are taken from treeQueue (used as a stack, so depth-first).
 * Each node is split on the not-yet-used attribute with the highest gain
 * ratio. It becomes a leaf labelled with its majority class ('0' + k) when
 * its depth exceeds maxDeep or no unused attribute reaches a gain ratio of
 * 1e-6 (e.g. the node is pure).
 */
void CreateDecisionTree( DesTreePoint *tree,const int datasize, const int attrinum, char **data,int *attriSome )
{
    const int featureNum = attrinum - 1; // every column except the class label
    const double minRatio = 0.000001;
    std::vector<bool> lockAttr(featureNum > 0 ? featureNum : 0); // attributes already used on the path
    std::vector<attrPoint> path; // split conditions from the current node up to the root

    // Start from a clean state so a second call does not append after the previous tree.
    locp = 0;
    treeQueue.clear();

    DesTreePoint &root = tree[0];
    root.prepos.attrName = -1; // marks the root when walking up the tree
    root.prepos.attrValue = '#';
    root.curpos = -1;
    root.deep = 0;
    root.nextPointNum = 0;
    root.label = '#';
    root.loca = 0;
    root.preloca = -1;
    treeQueue.push_back(root);

    while (!treeQueue.empty())
    {
        const int loca = treeQueue.back().loca;
        treeQueue.pop_back();
        DesTreePoint &node = tree[loca];

        // Collect the conditions that select this node's rows and lock the
        // attributes already tested on the way down.
        path.clear();
        lockAttr.assign(lockAttr.size(), false);
        for (int at = loca; tree[at].prepos.attrName != -1; at = tree[at].preloca)
        {
            path.push_back(tree[at].prepos);
            lockAttr[tree[at].prepos.attrName] = true;
        }
        attrPoint *conditions = path.empty() ? NULL : &path[0];
        const int depth = (int)path.size();

        // Choose the best unused attribute for THIS node (the running maximum
        // is local, so earlier nodes cannot influence the choice). Each call
        // also reports the node's majority class through classfiedLabel.
        int classfiedLabel = -1;
        int best = -1;
        double bestRatio = 0.0;
        for (int i = 0; i < featureNum; ++i)
        {
            if (lockAttr[i])
                continue;
            const double ratio = GainRatio(data, datasize, attrinum, depth, conditions, attriSome, i, classfiedLabel);
            // NaN ratios fail both comparisons and are never selected.
            if (ratio >= minRatio && ratio > bestRatio)
            {
                best = i;
                bestRatio = ratio;
            }
        }

        if (best == -1 || node.deep > maxDeep)
        {
            // If every attribute was already used, no call above reported the majority class.
            if (classfiedLabel < 0 && featureNum > 0)
                GainRatio(data, datasize, attrinum, depth, conditions, attriSome, 0, classfiedLabel);
            if (classfiedLabel < 0)
                classfiedLabel = 0;
            node.label = (char)(classfiedLabel + '0');
            node.curpos = -1;
            node.nextPointNum = 0;
            continue;
        }

        node.curpos = best;
        node.nextPointNum = attriSome[best];
        for (int v = 0; v < node.nextPointNum; ++v)
        {
            ++locp;
            DesTreePoint &child = tree[locp];
            child.prepos.attrName = best;
            child.prepos.attrValue = (char)(v + '0');
            child.curpos = -1;
            child.deep = node.deep + 1;
            child.nextPointNum = 0;
            child.label = '#';
            child.loca = locp;
            child.preloca = loca;
            node.nextloca[v] = locp;
            treeQueue.push_back(child);
        }
    }
}
/*
 * predict - classify one encoded row with the tree and check it against the row's label.
 *
 * Starting at the root (tree[0]), follows the child whose split value
 * (prepos.attrValue) equals the row's value for that child's attribute until
 * a node whose label is not '#' (a leaf) is reached.
 *
 * Returns true when that leaf's label equals the row's class column
 * data[attrinum - 1]. Returns false otherwise, including when no child matches
 * the row's value; this guarantees the walk always terminates.
 */
bool predict(DesTreePoint *tree, const char *data, const int attrinum )
{
    int node = 0;
    while (tree[node].label == '#')
    {
        int next = -1;
        for (int i = 0; i < tree[node].nextPointNum; ++i)
        {
            const int child = tree[node].nextloca[i];
            if (data[tree[child].prepos.attrName] == tree[child].prepos.attrValue)
            {
                next = child;
                break;
            }
        }
        if (next == -1)
            return false; // value not covered by any branch
        node = next;
    }
    return tree[node].label == data[attrinum - 1];
}
#endif
<main.cpp>
/****************************
DATA: Car Evaluation Database(from UCI)
决策树学习算法C实现(彭杰)
Copyright 2015/3/21 owner by pengjie
All rights reserved
*******************************/
#include "main.h"
// Shape of the data set handled by main(); the last column holds the class label.
const int dataline = 1728; // records read from the data file
const int attrinum = 7;    // columns per record
const int maxTree = 1000;  // capacity of the decision-tree node array
void SetMapAttr(StringCharMap *mapAttr); // fills the per-column text -> code maps (defined below)
/*
 * Train a decision tree on the first `train` records and print the accuracy on
 * the remaining records. The data file path may be given as the first
 * command-line argument; otherwise the built-in default path is used.
 * Returns 0 on success, 1 when the data file cannot be read.
 */
int main(int argc, char *argv[])
{
    DesTreePoint *tree = new DesTreePoint[maxTree];
    // A char array (not a char* to a string literal, which C++11 forbids).
    char defaultFile[] = "C:\\Users\\Administrator\\Desktop\\machine_data\\car.data";
    char *file = (argc > 1) ? argv[1] : defaultFile;
    int attriSome[attrinum] = {4,4,4,3,3,3,4}; // distinct values per column; last column is the class
    StringCharMap mapAttr[attrinum];
    SetMapAttr( mapAttr );
    const int train = 1650; // rows [0, train) train the tree, rows [train, dataline) test it
    char **data = new char*[dataline];
    for (int i = 0; i < dataline; ++i)
        data[i] = new char[attrinum];

    int status = 0;
    if (ReadData(data, attriSome, mapAttr, file, dataline, attrinum) == -1)
    {
        // Without data there is no tree; predicting would read uninitialized nodes.
        status = 1;
    }
    else
    {
        CreateDecisionTree(tree, train, attrinum, data, attriSome);
        int correct = 0;
        int sum = 0;
        for (int i = train; i < dataline; ++i) // the first test row is `train` itself
        {
            ++sum;
            if (predict(tree, data[i], attrinum))
                ++correct;
        }
        if (sum > 0)
            printf("the right correction: %f\n", double(correct) / sum);
        else
            printf("no test data left after %d training rows\n", train);
    }

    for (int i = 0; i < dataline; ++i)
        delete [] data[i];
    delete [] data;
    delete [] tree;
    return status;
}
/*
 * SetMapAttr - fill one text -> code map per data column.
 * The k-th name listed for a column (counting from 0) is encoded as the
 * character '0' + k; column 6 (the last one) holds the class labels.
 */
void SetMapAttr( StringCharMap *mapAttr )
{
    static const char *const names[][5] = {
        { "vhigh", "high", "med", "low", NULL },   // column 0
        { "vhigh", "high", "med", "low", NULL },   // column 1
        { "2", "3", "4", "5more", NULL },          // column 2
        { "2", "4", "more", NULL },                // column 3
        { "small", "med", "big", NULL },           // column 4
        { "low", "med", "high", NULL },            // column 5
        { "unacc", "acc", "good", "vgood", NULL }, // column 6 (class)
    };
    const int columns = (int)(sizeof names / sizeof names[0]);
    for (int col = 0; col < columns; ++col)
        for (int k = 0; names[col][k] != NULL; ++k)
            mapAttr[col][names[col][k]] = (char)('0' + k);
}
本文原计划在附件中提供所用的数据及数据说明（数据来自UCI数据库），但没有找到上传数据的地方，可以在以下链接下载所有相关文件：
http://download.youkuaiyun.com/detail/u200812705/8520817