本文转载http://www.cnblogs.com/xiaoyi115/p/3696107.html
算法的基本思想可以概括为:
1)树以代表训练样本的根结点开始。
2)如果样本都在同一个类.则该结点成为树叶,并记录该类。
3)否则,算法选择最有分类能力的属性作为决策树的当前结点.
4 )根据当前决策结点属性取值的不同,将训练样本根据该属性的值分为若干子集,每个取值形成一个分枝,有几个取值形成几个分枝。匀针对上一步得到的一个子集,重复进行先前步骤,递归形成每个划分样本上的决策树。一旦一个属性只出现在一个结点上,就不必在该结点的任何后代考虑它,直接标记类别。
5)递归划分步骤仅当下列条件之一成立时停止:
①给定结点的所有样本属于同一类。
②没有剩余属性可以用来进一步划分样本.在这种情况下.使用多数表决,将给定的结点转换成树叶,并以样本中元组个数最多的类别作为类别标记,同时也可以存放该结点样本的类别分布[这个主要可以用来剪枝]。
③如果某一分枝tc,没有满足该分支中已有分类的样本,则以样本的多数类生成叶子节点。
算法中2)步所指的最优分类能力的属性。这个属性的选择是本算法种的关键点,分裂属性的选择直接关系到此算法的优劣。
一般来说可以用比较信息增益和信息增益率的方式来进行。
其中信息增益的概念又会牵扯出熵的概念。熵的概念是香农在研究信息量方面的提出的。它的计算公式是:
Info(D)=-p1log(p1)/log(2.0)-p2log(p2)/log(2.0)-p3log(p3)/log(2.0)+...-pNlog(pN)/log(2.0) (其中N表示所有的不同类别)
而信息增益为:
Gain(A)=Info(D)-Info(Da) 其中Info(Da)数据集在属性A的情况下的信息量(熵)。
代码参考http://blog.youkuaiyun.com/wuyanyi/article/details/7974775/
三个文件:
测试数据:data.txt
- D1 Sunny Hot High Weak No
- D2 Sunny Hot High Strong No
- D3 Overcast Hot High Weak Yes
- D4 Rain Mild High Weak Yes
- D5 Rain Cool Normal Weak Yes
- D6 Rain Cool Normal Strong No
- D7 Overcast Cool Normal Strong Yes
- D8 Sunny Mild High Weak No
- D9 Sunny Cool Normal Weak Yes
- D10 Rain Mild Normal Weak Yes
- D11 Sunny Mild Normal Strong Yes
- D12 Overcast Mild High Strong Yes
- D13 Overcast Hot Normal Weak Yes
- D14 Rain Mild High Strong No
- #ifndef ID3_H
- #define ID3_H
- #include<fstream>
- #include<iostream>
- #include<vector>
- #include<map>
- #include<set>
- #include<cmath>
- using namespace std;
- const int DataRow=14;
- const int DataColumn=6;
- struct Node
- {
- double value;//代表此时yes的概率。
- int attrid;
- Node * parentNode;
- vector<Node*> childNode;
- };
- #endif
程序源文件id3.cpp
- #include "id3.h"
- string DataTable[DataRow][DataColumn];
- map<string,int> str2int;
- set<int> S;
- set<int> Attributes;
- string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
- string attrValue[DataColumn][DataRow]=
- {
- {},//D1,D2这个属性不需要
- {"Sunny","Overcast","Rain"},
- {"Hot","Mild","Cool"},
- {"High","Normal"},
- {"Weak","Strong"},
- {"No","Yes"}
- };
- int attrCount[DataColumn]={14,3,3,2,2,2};
- double lg2(double n)
- {
- return log(n)/log(2);
- }
- void Init()
- {
- ifstream fin("data.txt");
- for(int i=0;i<14;i++)
- {
- for(int j=0;j<6;j++)
- {
- fin>>DataTable[i][j];
- }
- }
- fin.close();
- for(int i=1;i<=5;i++)
- {
- str2int[attrName[i]]=i;
- for(int j=0;j<attrCount[i];j++)
- {
- str2int[attrValue[i][j]]=j;
- }
- }
- for(int i=0;i<DataRow;i++)
- S.insert(i);
- for(int i=1;i<=4;i++)
- Attributes.insert(i);
- }
- double Entropy(const set<int> &s)
- {
- double yes=0,no=0,sum=s.size(),ans=0;
- for(set<int>::iterator it=s.begin();it!=s.end();it++)
- {
- string s=DataTable[*it][str2int["PlayTennis"]];
- if(s=="Yes")
- yes++;
- else
- no++;
- }
- if(no==0||yes==0)
- return ans=0;
- ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
- return ans;
- }
- double Gain(const set<int> & example,int attrid)
- {
- int attrcount=attrCount[attrid];
- double ans=Entropy(example);
- double sum=example.size();
- set<int> * pset=new set<int>[attrcount];
- for(set<int>::iterator it=example.begin();it!=example.end();it++)
- {
- pset[str2int[DataTable[*it][attrid]]].insert(*it);
- }
- for(int i=0;i<attrcount;i++)
- {
- ans-=pset[i].size()/sum*Entropy(pset[i]);
- }
- return ans;
- }
- int FindBestAttribute(const set<int> & example,const set<int> & attr)
- {
- double mx=0;
- int k=-1;
- for(set<int>::iterator i=attr.begin();i!=attr.end();i++)
- {
- double ret=Gain(example,*i);
- if(ret>mx)
- {
- mx=ret;
- k=*i;
- }
- }
- if(k==-1)
- cout<<"FindBestAttribute error!"<<endl;
- return k;
- }
- Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)
- {
- Node *now=new Node;//创建树节点。
- now->parentNode=parent;
- if(attributes.empty())//如果此时属性列表已用完,即为空,则返回。
- return now;
- /*
- * 统计一下example,如果都为正或者都为负则表示已经抵达决策树的叶子节点
- * 叶子节点的特征是有childNode为空。
- */
- int yes=0,no=0,sum=example.size();
- for(set<int>::iterator it=example.begin();it!=example.end();it++)
- {
- string s=DataTable[*it][str2int["PlayTennis"]];
- if(s=="Yes")
- yes++;
- else
- no++;
- }
- if(yes==sum||yes==0)
- {
- now->value=yes/sum;
- return now;
- }
- /*找到最高信息增益的属性并将该属性从attributes集合中删除*/
- int bestattrid=FindBestAttribute(example,attributes);
- now->attrid=bestattrid;
- attributes.erase(attributes.find(bestattrid));
- /*将exmple根据最佳属性的不同属性值分成几个分支,每个分支有即一个子树*/
- vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);
- for(set<int>::iterator i=example.begin();i!=example.end();i++)
- {
- int id=str2int[DataTable[*i][bestattrid]];
- child[id].insert(*i);
- }
- for(int i=0;i<child.size();i++)
- {
- Node * ret=Id3_solution(child[i],attributes,now);
- now->childNode.push_back(ret);
- }
- return now;
- }
- int main()
- {
- Init();
- Node * Root=Id3_solution(S,Attributes,NULL);
- return 0;
- }
本文详细介绍了一种基于信息增益的决策树算法实现过程,包括算法基本思想、关键步骤及具体代码实现。
1844

被折叠的 条评论
为什么被折叠?



