机器学习算法之C4.5(C语言实现)

本文介绍了一种基于C4.5算法的决策树构建方法,并通过C语言实现了该算法。通过对UCI数据库中的CarEvaluation数据集进行训练,构建了一个能够达到约70%准确率的决策树模型。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

C4.5的学习

写的很简单,可以看下面的博客资源

http://www.cnblogs.com/biyeymyhjob/archive/2012/07/23/2605208.html

http://computerght.blog.163.com/blog/static/522736862013102992416752/

 

C4.5算法相关资料

C4.5算法是在ID3算法的基础上改进而来的:ID3算法利用信息增益构建决策树,而C4.5利用信息增益率来构建决策树,并且加入了剪枝步骤来缓解过拟合。

以下公式用于计算所需信息。

  1.  Info(D)  =  -∑pi*log2(pi) ;  对类标签计算信息熵
  2.  InfoA(D) =  ∑ (|Di|/|D|)*Info(Di); 按属性A划分后各子集信息熵的加权和(注意这里没有负号)
  3. Gain(A) = Info(D) - InfoA(D);  对属性A计算信息增益

ID3算法利用信息增益构建决策树,信息增益越大,则考虑以此属性进行分裂

C4.5算法考虑到信息增益偏向于取值较多的属性,因此改用信息增益率来选择作为分裂点的属性。

  1. SplitInfoA(D) = -∑(Di/D)*log2(Di/D);  分裂信息
  2. GainRatioA(D) = Gain(A)/SplitInfoA(D); 信息增益率

以上C4.5算法需要用到的公式。

后面使用UCI数据库中的数据,用C++实现(整体是C风格的写法,但用到了STL的map、vector和string)。代码全部为原创,可能有点乱,后面根据情况可能会进一步修改;用得到的决策树对测试样本进行分类,可以得到70%左右的正确率。

结果小结:

1. 程序中通过限制树的深度(maxDeep)来避免决策树过于复杂,也可以采用其他思路进行剪枝。

2. 训练数据集采用了前1650个样本(总共1728个,与代码中的 train = 1650 一致),其余样本用于测试;更好的做法是随机划分训练集和测试集。

只有两个代码文件,没有将函数分开,功能都在main.h实现,具体的用例在main.cpp中。

<main.h>

<pre name="code" class="cpp">/****************************

DATA: Car Evaluation Database(from UCI)

决策树学习算法C实现(彭杰)

Copyright 2015/3/21 owner by pengjie
All rights reserved

*******************************/


#ifndef MAIN_H
#define MAIN_H

#include "stdio.h"
#include "stdlib.h"
#include <map>
#include <iostream>
#include <string>
#include <string.h>
#include <vector>
#include <math.h>

// Capacity of DesTreePoint::nextloca (child links per node) and of the
// root-path buffer that CreateDecisionTree fills for every node.
#define maxClass 10
// Depth threshold: CreateDecisionTree turns a node into a leaf once its depth exceeds this value.
#define maxDeep 3
using namespace std;

// One lookup table per data column: raw text token -> single-character code.
typedef map<string, char> StringCharMap;
// Index of the most recently allocated node in the tree array; advanced by CreateDecisionTree.
int locp = 0;


// One "attribute == value" test, i.e. the edge that leads from a parent node to a child.
struct attrPoint
{
	int attrName;   // attribute (column) index; -1 marks "no attribute" (used for the root)
	char attrValue; // encoded attribute value ('0' + value index)

	// Start in the well-defined "no attribute" state instead of leaving the
	// members indeterminate.
	attrPoint() : attrName(-1), attrValue('\0') {}
};
struct DesTreePoint
{
	attrPoint prepos;//previous point
	int curpos;//current point
	int deep;//current deep
	int nextPointNum;
	//attrPoint nextattr[maxClass];//next attrname,0:none next point
	char label;//classify
	int loca;//
	int preloca;
	int nextloca[maxClass];
};
// Nodes that still have to be processed by CreateDecisionTree, which pushes and pops at the back.
vector<DesTreePoint> treeQueue;

// Debug helper: dumps the bookkeeping fields of one tree node to stdout,
// one field per line (the incoming edge first, then position/depth/links).
void print(DesTreePoint point)
{
	const attrPoint &incoming = point.prepos;

	printf("prepos--name,value:%d,%c\n", incoming.attrName, incoming.attrValue);
	printf("curpos: %d\n", point.curpos);
	printf("deep: %d\n", point.deep);
	printf("nextPointNum: %d\n", point.nextPointNum);
	printf("loca: %d\n", point.loca);
	printf("preloca:%d\n", point.preloca);
}

int ReadData(char **data, int attriSome[], StringCharMap *mapAttr, char *file,const int dataline,const int attrinum)
{
	FILE *pFile;
	char buf[256];
	pFile = fopen(file,"rt");
	if(pFile==NULL)
	{
		printf("the data file is not existing: %s\n", file);
		return -1;
	}
	
	int row = 0;    //data line
	int cloumn = 0; //data attribute
	char delim[] = ",";//data delimiter
	string tmpdata;//data cache
	while(!feof(pFile)&&row<dataline)
	{
		fgets(buf,256,pFile);

		//printf("%s\t  %d\n",buf,row);
		/*printf("%d-%c\n",strlen(buf),buf[strlen(buf)-1]);*/
		/*buf[strlen(buf)-1]=='\n';*/
		if( buf[strlen(buf)-1]=='\n' )
		{
			buf[strlen(buf)-1]='\0';
			
		}

		for(cloumn=0;cloumn<attrinum;++cloumn )
		{ 
			if( cloumn==0 )
			{
				tmpdata = strtok(buf,delim);
				//tmpdata[strlen(tmpdata)] = '\0';
				//printf("%s\t",tmpdata.c_str());

				data[row][cloumn] = mapAttr[cloumn][tmpdata];
			}
			else
			{
				tmpdata = strtok(NULL,delim);
				//tmpdata[strlen(tmpdata)] = '\0';
				//printf("%s,%d\t",tmpdata.c_str(),strlen(tmpdata.c_str()));
				data[row][cloumn] = mapAttr[cloumn][tmpdata];
			}
			
		}
		//printf("\n");
		++row;
	}

	return 1;

}

// Contribution -p*log2(p) of one probability to an entropy sum; probabilities
// below 1e-6 count as zero.
static double EntropyTerm(double p)
{
	if( p<0.000001 )
		return 0.0;
	return -p*log(p)/log(2.0);
}

// Computes the C4.5 gain ratio of attribute `attrname` on the samples that
// reach a node.
//
//   data           - encoded samples; every cell holds '0' + value index
//   datasize       - number of rows of data to consider
//   attrinum       - number of columns; the last column is the class label
//   deep           - number of entries in attrp
//   attrp          - (attribute, value) tests on the path from the root; a sample
//                    belongs to the node only if it satisfies all of them
//                    (may be NULL when deep is 0)
//   attriSome      - number of distinct values of every column
//   attrname       - column whose gain ratio is wanted
//   classfiedLabel - out: index of the most frequent class among the node's
//                    samples (0 when the node has no samples)
//
// Returns 0.0 when the node needs no further split (no samples, zero entropy,
// or one class holds more than 95% of the samples) or when the attribute does
// not partition the samples (split information is zero); otherwise
// (Info(D) - InfoA(D)) / SplitInfoA(D).
double GainRatio( char **data,const int datasize, const int attrinum,int deep,attrPoint *attrp,int *attriSome,int attrname,int &classfiedLabel )
{
	const int labelnum = attriSome[attrname];      //number of values of the attribute
	const int classifynum = attriSome[attrinum-1]; //number of classes

	// Containers instead of new[]: the buffers are released on every return path.
	std::vector< std::vector<int> > counts(labelnum, std::vector<int>(classifynum, 0)); //[value][class]
	std::vector<int> partitionSize(labelnum, 0); //samples per attribute value
	std::vector<int> classSize(classifynum, 0);  //samples per class
	int total = 0;                               //samples that reach the node

	for(int i=0;i<datasize;++i)
	{
		// With deep == 0 the loop body never runs and every sample is accepted.
		bool onPath = true;
		for(int j=0; j<deep && onPath; ++j)
			onPath = ( data[i][attrp[j].attrName]==attrp[j].attrValue );
		if( !onPath )
			continue;

		const int attrlabel = data[i][attrname]-'0';
		const int classifylabel = data[i][attrinum-1]-'0';
		// Codes outside the declared ranges would index out of bounds; ignore such rows.
		if( attrlabel<0 || attrlabel>=labelnum || classifylabel<0 || classifylabel>=classifynum )
			continue;

		++counts[attrlabel][classifylabel];
		++partitionSize[attrlabel];
		++classSize[classifylabel];
		++total;
	}

	classfiedLabel = 0;
	if( total==0 )
		return 0.0; //nothing reaches this node; avoids 0/0 below

	//infoD and majority class
	double infoD = 0.0;
	double maxpi = 0.0;
	int maxindex = 0;
	for(int i=0;i<classifynum;++i)
	{
		const double pi = double(classSize[i])/total;
		if( pi>maxpi )
		{
			maxpi = pi;
			maxindex = i;
		}
		infoD += EntropyTerm(pi);
	}
	classfiedLabel = maxindex;

	if( fabs(infoD)<0.0000001 || maxpi>0.95 )
		return 0.0;

	//infoDA and splitInfo
	double infoDA = 0.0;
	double splitInfo = 0.0;
	for(int i=0;i<labelnum;++i)
	{
		if( partitionSize[i]==0 )
			continue; //empty partition contributes nothing (and would divide by zero)

		// Sum over all classes; a class with zero samples adds 0 and must not
		// discard what the other classes already contributed.
		double infoDi = 0.0;
		for(int j=0;j<classifynum;++j)
			infoDi += EntropyTerm( double(counts[i][j])/partitionSize[i] );

		const double weight = double(partitionSize[i])/total;
		infoDA += weight*infoDi;
		splitInfo += EntropyTerm(weight);
	}

	if( splitInfo<0.000001 )
		return 0.0; //all samples share one value: the attribute cannot split them

	return ( (infoD-infoDA)/splitInfo );
}

// Builds the decision tree for the first `datasize` rows of `data`.
//
//   tree      - output node array; tree[0] becomes the root. The caller must
//               provide room for every node that gets created.
//   datasize  - number of training rows
//   attrinum  - number of columns; the last one is the class label
//   data      - encoded samples ('0' + value index per cell)
//   attriSome - number of distinct values per column (each must fit in
//               DesTreePoint::nextloca, i.e. be at most maxClass)
//
// Every node is split on the not-yet-used attribute with the largest gain
// ratio. A node becomes a leaf, labelled with its majority class, when it is
// deeper than maxDeep, when no unused attribute is left, or when no unused
// attribute reaches a gain ratio of 1e-6.
//
// Uses and resets the file-level `locp` and `treeQueue`, so each call starts a
// fresh tree at index 0.
void CreateDecisionTree( DesTreePoint *tree,const int datasize, const int attrinum, char **data,int *attriSome )
{
	const int featureCount = attrinum-1; //the last column is the class label

	locp = 0;
	treeQueue.clear();

	//root node: no incoming edge, no parent
	tree[0].prepos.attrName = -1;
	tree[0].prepos.attrValue = '\0';
	tree[0].curpos = -1;
	tree[0].deep = 0;
	tree[0].loca = 0;
	tree[0].label = '#';
	tree[0].preloca = -1;
	tree[0].nextPointNum = 0;
	treeQueue.push_back(tree[0]);

	std::vector<char> lockAttr(attrinum, 0); //1: attribute already used on the path to the node
	attrPoint path[maxClass];                //tests on the path from the node up to the root

	while( !treeQueue.empty() )
	{
		const DesTreePoint node = treeQueue.back();
		treeQueue.pop_back();
		const int here = node.loca;

		// Walk up to the root, collecting the tests that define this node's
		// sample subset and locking the attributes they use. The walk is at
		// most maxDeep+1 steps long; the extra bound protects the buffer.
		lockAttr.assign(attrinum, 0);
		int pathLen = 0;
		for(int at=here; tree[at].prepos.attrName!=-1 && pathLen<maxClass; at=tree[at].preloca)
		{
			path[pathLen] = tree[at].prepos;
			lockAttr[tree[at].prepos.attrName] = 1;
			++pathLen;
		}

		// Pick the best unused attribute. The running maximum is local to the
		// node, and attribute 0 is a candidate like any other.
		int best = -1;
		double bestValue = -1.0;
		int majority = -1; //majority class of the node, reported by GainRatio
		if( node.deep<=maxDeep )
		{
			for(int i=0;i<featureCount;++i)
			{
				if( lockAttr[i] )
					continue;
				const double ratio = GainRatio(data,datasize,attrinum,pathLen,path,attriSome,i,majority);
				// An uninformative (or NaN) ratio never wins, but it does not
				// stop the search for an informative attribute either.
				if( ratio>=0.000001 && ratio>bestValue )
				{
					best = i;
					bestValue = ratio;
				}
			}
		}

		if( best<0 )
		{
			// Leaf. If no attribute was evaluated above, one call is still
			// needed to learn the majority class of the node's samples.
			if( majority<0 )
				GainRatio(data,datasize,attrinum,pathLen,path,attriSome,0,majority);
			tree[here].label = static_cast<char>(majority+48);
			tree[here].nextPointNum = 0;
			continue;
		}

		//inner node: one child per value of the chosen attribute
		tree[here].curpos = best;
		tree[here].nextPointNum = attriSome[best];
		for(int v=0; v<tree[here].nextPointNum; ++v)
		{
			++locp;
			tree[here].nextloca[v] = locp;

			DesTreePoint &child = tree[locp];
			child.prepos.attrName = best;
			child.prepos.attrValue = static_cast<char>(v+48);
			child.deep = tree[here].deep+1;
			child.loca = locp;
			child.preloca = here;
			child.label = '#';
			child.nextPointNum = 0;
			treeQueue.push_back(child);
		}
	}//while
}

// Classifies one encoded sample with the tree and checks the result.
//
//   tree     - node array whose element 0 is the root
//   data     - encoded sample; data[attrinum-1] holds the true class code
//   attrinum - number of columns of the sample
//
// Returns true when the leaf reached by the sample carries the sample's own
// class code. Returns false on a wrong prediction and also when the sample
// cannot be routed to a leaf (no child of an inner node matches its value).
bool predict(DesTreePoint *tree,char *data, const int attrinum )
{
	int here = 0; //index of the node currently visited

	// '#' marks an inner node; anything else is a leaf label.
	while( tree[here].label=='#' )
	{
		int next = -1;
		for(int i=0;i<tree[here].nextPointNum;++i)
		{
			const int child = tree[here].nextloca[i];
			const attrPoint &edge = tree[child].prepos;
			if( data[edge.attrName]==edge.attrValue )
			{
				next = child;
				break; //child values are distinct, so at most one edge matches
			}
		}

		// Without this check an unmatched sample would spin here forever.
		if( next<0 )
			return false;
		here = next;
	}

	return tree[here].label==data[attrinum-1];
}
		
					
#endif

 <main.cpp>

/****************************

DATA: Car Evaluation Database(from UCI)

决策树学习算法C实现(彭杰)

Copyright 2015/3/21 owner by pengjie
All rights reserved

*******************************/


#include "main.h"

// Number of records read from the data file (rows of the data table).
const int dataline = 1728;
// Number of columns per record; main's attriSome lists 7 value counts, one per column.
const int attrinum = 7;
// Number of DesTreePoint nodes allocated for the tree in main.
const int maxTree = 1000;

// Fills the per-column token -> code tables (defined below main).
void SetMapAttr( StringCharMap *mapAttr );


int main()
{

	
	DesTreePoint *tree = new DesTreePoint[maxTree];
	//pre-input
	char *file = "C:\\Users\\Administrator\\Desktop\\machine_data\\car.data";
	int attriSome[attrinum] = {4,4,4,3,3,3,4};
	StringCharMap mapAttr[attrinum];
	SetMapAttr( mapAttr );

	int train = 1650; //the number of train data

	
	char **data = new char*[dataline];
	for(int i=0;i<dataline;++i)
		data[i] = new char[attrinum];

	if( -1!=ReadData(data,attriSome,mapAttr,file,dataline,attrinum) )
	{
		CreateDecisionTree( tree,train, attrinum, data,attriSome );
	}


	int correct = 0;
	int sum = 0;
	for(int i=(train+1);i<dataline; ++i)
	{
		++sum;
		bool eva = predict(tree,data[i],attrinum);
		if(eva)
			++correct;
	}

	double rp = double(correct)/sum;
	printf("the right correction: %f\n",rp);


	

	for(int i=0;i<dataline;++i)
		delete []data[i];
	delete []data;

	delete []tree;

	return 0;
}

// Populates the token -> code table of each of the seven data columns.
// Within a column the tokens are numbered in the order listed below, and the
// stored code is the character '0' plus that position.
void SetMapAttr( StringCharMap *mapAttr )
{
	// One row per column; a NULL entry ends the row.
	static const char *const vocabulary[][5] =
	{
		{ "vhigh", "high", "med", "low",   NULL },
		{ "vhigh", "high", "med", "low",   NULL },
		{ "2",     "3",    "4",   "5more", NULL },
		{ "2",     "4",    "more", NULL,   NULL },
		{ "small", "med",  "big",  NULL,   NULL },
		{ "low",   "med",  "high", NULL,   NULL },
		{ "unacc", "acc",  "good", "vgood", NULL }
	};
	const int columnCount = sizeof(vocabulary)/sizeof(vocabulary[0]);

	for(int column=0; column<columnCount; ++column)
	{
		for(int position=0; vocabulary[column][position]!=NULL; ++position)
		{
			mapAttr[column][vocabulary[column][position]] = static_cast<char>('0'+position);
		}
	}
}
	


附件中添加了所用的数据及数据说明。数据来自UCI数据库。
实在没有看到在哪里上传数据,可以在以下资源里面下载所有相关文件。

http://download.youkuaiyun.com/detail/u200812705/8520817

 



 

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值