机器学习算法-决策树ID3

本文详细介绍了一种基于信息增益的决策树算法实现过程,包括算法基本思想、关键步骤及具体代码实现。

本文转载http://www.cnblogs.com/xiaoyi115/p/3696107.html

算法的基本思想可以概括为:

    1)树以代表训练样本的根结点开始。
    2)如果样本都在同一个类.则该结点成为树叶,并记录该类。
    3)否则,算法选择最有分类能力的属性作为决策树的当前结点.
    4 )根据当前决策结点属性取值的不同,将训练样本根据该属性的值分为若干子集,每个取值形成一个分枝,有几个取值形成几个分枝。匀针对上一步得到的一个子集,重复进行先前步骤,递归形成每个划分样本上的决策树。一旦一个属性只出现在一个结点上,就不必在该结点的任何后代考虑它,直接标记类别。
    5)递归划分步骤仅当下列条件之一成立时停止:
    ①给定结点的所有样本属于同一类。
    ②没有剩余属性可以用来进一步划分样本.在这种情况下.使用多数表决,将给定的结点转换成树叶,并以样本中元组个数最多的类别作为类别标记,同时也可以存放该结点样本的类别分布[这个主要可以用来剪枝]。
    ③如果某一分枝tc,没有满足该分支中已有分类的样本,则以样本的多数类生成叶子节点。
  算法中2)步所指的最优分类能力的属性。这个属性的选择是本算法种的关键点,分裂属性的选择直接关系到此算法的优劣。
   一般来说可以用比较信息增益和信息增益率的方式来进行。
  其中信息增益的概念又会牵扯出熵的概念。熵的概念是香农在研究信息量方面的提出的。它的计算公式是:
    Info(D)=-p1log(p1)/log(2.0)-p2log(p2)/log(2.0)-p3log(p3)/log(2.0)+...-pNlog(pN)/log(2.0)    (其中N表示所有的不同类别)
  而信息增益为:
            Gain(A)=Info(D)-Info(Da)             其中Info(Da)数据集在属性A的情况下的信息量(熵)。
代码参考http://blog.youkuaiyun.com/wuyanyi/article/details/7974775/

三个文件:

测试数据:data.txt

[plain]  view plain copy
  1. D1    Sunny        Hot    High        Weak    No  
  2. D2    Sunny        Hot    High        Strong    No  
  3. D3    Overcast    Hot    High        Weak    Yes  
  4. D4    Rain        Mild    High        Weak    Yes  
  5. D5    Rain        Cool    Normal        Weak    Yes  
  6. D6    Rain        Cool    Normal        Strong    No  
  7. D7    Overcast    Cool    Normal        Strong    Yes  
  8. D8    Sunny        Mild    High        Weak    No  
  9. D9    Sunny        Cool    Normal        Weak    Yes  
  10. D10    Rain        Mild    Normal        Weak    Yes  
  11. D11    Sunny        Mild    Normal        Strong    Yes  
  12. D12    Overcast    Mild    High        Strong    Yes  
  13. D13    Overcast    Hot    Normal        Weak    Yes  
  14. D14    Rain        Mild    High        Strong    No  


程序头文件:id3.h
[cpp]  view plain copy
  1. #ifndef ID3_H  
  2. #define ID3_H  
  3. #include<fstream>  
  4. #include<iostream>  
  5. #include<vector>  
  6. #include<map>  
  7. #include<set>  
  8. #include<cmath>  
  9. using namespace std;  
  10. const int DataRow=14;  
  11. const int DataColumn=6;  
  12. struct Node  
  13. {  
  14.     double value;//代表此时yes的概率。  
  15.     int attrid;  
  16.     Node * parentNode;  
  17.     vector<Node*> childNode;  
  18. };  
  19. #endif  

程序源文件id3.cpp

[cpp]  view plain copy
  1. #include "id3.h"  
  2.   
  3. string DataTable[DataRow][DataColumn];  
  4. map<string,int> str2int;  
  5. set<int> S;  
  6. set<int> Attributes;  
  7. string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};  
  8. string attrValue[DataColumn][DataRow]=  
  9. {  
  10.     {},//D1,D2这个属性不需要  
  11.     {"Sunny","Overcast","Rain"},  
  12.     {"Hot","Mild","Cool"},  
  13.     {"High","Normal"},  
  14.     {"Weak","Strong"},  
  15.     {"No","Yes"}  
  16. };  
  17. int attrCount[DataColumn]={14,3,3,2,2,2};  
  18. double lg2(double n)  
  19. {  
  20.     return log(n)/log(2);  
  21. }  
  22. void Init()  
  23. {  
  24.     ifstream fin("data.txt");  
  25.     for(int i=0;i<14;i++)  
  26.     {  
  27.       for(int j=0;j<6;j++)  
  28.       {  
  29.           fin>>DataTable[i][j];  
  30.       }  
  31.     }  
  32.     fin.close();  
  33.     for(int i=1;i<=5;i++)  
  34.     {  
  35.         str2int[attrName[i]]=i;  
  36.         for(int j=0;j<attrCount[i];j++)  
  37.         {  
  38.             str2int[attrValue[i][j]]=j;  
  39.         }  
  40.     }  
  41.     for(int i=0;i<DataRow;i++)  
  42.       S.insert(i);  
  43.     for(int i=1;i<=4;i++)  
  44.       Attributes.insert(i);  
  45. }  
  46.   
  47. double Entropy(const set<int> &s)  
  48. {  
  49.     double yes=0,no=0,sum=s.size(),ans=0;  
  50.     for(set<int>::iterator it=s.begin();it!=s.end();it++)  
  51.     {  
  52.         string s=DataTable[*it][str2int["PlayTennis"]];  
  53.         if(s=="Yes")  
  54.           yes++;  
  55.         else  
  56.           no++;  
  57.     }  
  58.     if(no==0||yes==0)  
  59.       return ans=0;  
  60.     ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);  
  61.     return ans;  
  62. }  
  63. double Gain(const set<int> & example,int attrid)  
  64. {  
  65.     int attrcount=attrCount[attrid];  
  66.     double ans=Entropy(example);  
  67.     double sum=example.size();  
  68.     set<int> * pset=new set<int>[attrcount];  
  69.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  70.     {  
  71.         pset[str2int[DataTable[*it][attrid]]].insert(*it);  
  72.     }  
  73.     for(int i=0;i<attrcount;i++)  
  74.     {  
  75.         ans-=pset[i].size()/sum*Entropy(pset[i]);  
  76.     }  
  77.     return ans;  
  78. }  
  79. int FindBestAttribute(const set<int> & example,const set<int> & attr)  
  80. {  
  81.     double mx=0;  
  82.     int k=-1;  
  83.     for(set<int>::iterator i=attr.begin();i!=attr.end();i++)  
  84.     {  
  85.         double ret=Gain(example,*i);  
  86.         if(ret>mx)  
  87.         {  
  88.             mx=ret;  
  89.             k=*i;  
  90.         }  
  91.     }  
  92.     if(k==-1)  
  93.       cout<<"FindBestAttribute error!"<<endl;  
  94.     return k;  
  95. }  
  96. Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)  
  97. {  
  98.     Node *now=new Node;//创建树节点。  
  99.     now->parentNode=parent;  
  100.     if(attributes.empty())//如果此时属性列表已用完,即为空,则返回。  
  101.       return now;  
  102.   
  103.     /* 
  104.      * 统计一下example,如果都为正或者都为负则表示已经抵达决策树的叶子节点 
  105.      * 叶子节点的特征是有childNode为空。 
  106.      */  
  107.     int yes=0,no=0,sum=example.size();  
  108.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  109.     {  
  110.         string s=DataTable[*it][str2int["PlayTennis"]];  
  111.         if(s=="Yes")  
  112.           yes++;  
  113.         else  
  114.           no++;  
  115.     }  
  116.     if(yes==sum||yes==0)  
  117.     {  
  118.         now->value=yes/sum;  
  119.         return now;  
  120.     }  
  121.       
  122.   
  123.     /*找到最高信息增益的属性并将该属性从attributes集合中删除*/  
  124.     int bestattrid=FindBestAttribute(example,attributes);  
  125.     now->attrid=bestattrid;  
  126.     attributes.erase(attributes.find(bestattrid));  
  127.       
  128.     /*将exmple根据最佳属性的不同属性值分成几个分支,每个分支有即一个子树*/  
  129.     vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);  
  130.     for(set<int>::iterator i=example.begin();i!=example.end();i++)  
  131.     {  
  132.         int id=str2int[DataTable[*i][bestattrid]];  
  133.         child[id].insert(*i);  
  134.     }  
  135.     for(int i=0;i<child.size();i++)  
  136.     {  
  137.         Node * ret=Id3_solution(child[i],attributes,now);  
  138.         now->childNode.push_back(ret);  
  139.     }  
  140.     return now;  
  141. }  
  142.   
  143. int main()  
  144. {  
  145.     Init();  
  146.     Node * Root=Id3_solution(S,Attributes,NULL);  
  147.     return 0;  
  148. }  

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值