“集体智慧编程”之第七章:决策树

什么是决策树?

如果将决策树和上一章的分类器一起讲述,那么决策树这种算法也是用于对物品分类的,书有一个非常简单的例子,能帮助我理解什么是决策树。
给你一个水果,你可以通过以下方式判断出这是一个什么水果。

可以看出,决策树上就是一个又一个if-then的语句联系起来的。而且从最终结果,我们也能够看出整个推理的过程。而上一章讲述的贝叶斯分类器里每一个单词的重要性通过计算而得到的。

实例


背景


书中还是通过实例来带领我们学习。背景:一个网站,网站的功能被分为了基本或高级,网站被很多用户使用了一段时间。用户在使用过程中,我们统计了一下信息:

  • 我们想知道哪些用户可能成为付费用户。
显然,如果高级功能(premium)是付费用户专属,那么我们就想知道哪种用户(来自哪里的,ip是哪的,浏览网页多少次的)更会使用高级功能,更可能成为我们的付费用户。比如:上图中,所有来自网站为google的用户都会使用premium高级功能,这样了,我们就知道应该加大在google上做广告的力度,因为google那边过来的用户更有可能使用高级功能,更可能成为我们的会员。
由此,注意两点:
  1. 使用决策树的时候,我们必然需要一些原始的数据对决策树进行训练,这些数据都包含了用户最终的选择、或者结果。也就是需要有输入和输出的数据。之前的博客也说过决策树属于监督类学习。
  2. 虽然在上面的例子中,我们能够直观的看出来自google的用户,更容易成为会员(看起来我们就已经找到了成为会员的关键)。但是实际情况不是那么简单,请继续看本博客就知道了。此外,如果原始数据特别多的时候,用肉眼观察是非常辛苦的。

数据集


书中为我们准备了数据集,就是上面那幅表的内容,用列表数组表示:
[python]  view plain  copy
  1. my_data=[['slashdot','USA','yes',18,'None'],  
  2.         ['google','France','yes',23,'Premium'],  
  3.         ['digg','USA','yes',24,'Basic'],  
  4.         ['kiwitobes','France','yes',23,'Basic'],  
  5.         ['google','UK','no',21,'Premium'],  
  6.         ['(direct)','New Zealand','no',12,'None'],  
  7.         ['(direct)','UK','no',21,'Basic'],  
  8.         ['google','USA','no',24,'Premium'],  
  9.         ['slashdot','France','yes',19,'None'],  
  10.         ['digg','USA','no',18,'None'],  
  11.         ['google','UK','no',18,'None'],  
  12.         ['kiwitobes','UK','no',19,'None'],  
  13.         ['digg','New Zealand','yes',12,'Basic'],  
  14.         ['slashdot','UK','no',21,'None'],  
  15.         ['google','UK','yes',18,'Basic'],  
  16.         ['kiwitobes','France','yes',19,'Basic']]  


树中节点的表示


就像上面那个水果的例子一样,决策树实际上是由一个又有一个的节点组成的。在代码中,我们用一个类来表示:
[python]  view plain  copy
  1. #这是决策树的表达形式:一个一个的节点。每一个节点有五个属性。  
  2. class decisionnode:  
  3.     def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):  
  4.         self.col=col #被判断条件的对应的列的索引号,如图中:来源网站,是否阅读过FAQ的列序号  
  5.         self.value=value#什么情况下被判定为真?就看这个value的值,如果value为Yes(对应是否阅读过FAQ),也是Yes的时候为真。如果是大于20(对应浏览页数)那就是大于20为真  
  6.         self.results=results#最终结果,只有叶节点才有这个值,分类结果或者判定的结果  
  7.         self.tb=tb#也是一个decisionnode,t代表true,为true的话,就走这个节点。叶节点,没有这个,但是有结果  
  8.         self.fb=fb#也是一个decisionnode,t代表false,为false的话,就走这个节点。叶节点,没有这个,但是有结果  

构建树

我们分类回归树的算法(CART,Classification And Regression Tree)来构建树。该算法的思想是:首先创建一个根节点,然后我们选择所有统计到里面的数据的一个(就是中的一个)来对初始数据集进行划分。比如,根节点处,我们选择是否阅读过FAQ来对节点进行拆分,就可以拆分成"看过"和”没看过”的两个数据集,抽取两个数据集的最后的一列,也就是我们关心的用户使用的功能,由此可知是否是有可能成为付费会员。如下图所示:

所以,根据我们这个需求,代码我们写了如下函数:

分解节点

[python]  view plain  copy
  1. #根据某一列,对数据进行拆分成两个set,一个set代表选true的时候,一个set代表选false的时候  
  2. def divideset(rows,column,value):  
  3.     #定义了一个新函数,用这个函数去判断每一行数据是属于第一组(true),还是第二组(false)  
  4.     split_function=None  
  5.     #根据value的值,如果value是数字的话,一般都是大于某个数,如果value是布尔的话,那就是为true。  
  6.     #为了使这个函数既能够,又能够接受数值类的判断,又能够结果布尔值,是与否的判断,才如此的。  
  7.     if isinstance(value,int) or isinstance(value,float):  
  8.         split_function=lambda row:row[column]>=value#lambda创建一个新函数,该函数的接受的参数为row,函数内容为row[column]>=value  
  9.     else:  
  10.         split_function=lambda row:row[column]==value  
  11.   
  12.     #将数据集根据上面的函数,以及为真条件,判断,并返回  
  13.     set1=[row for row in rows if split_function(row)]#用split_function函数判断一下,如果成功就是放在set1  
  14.     set2=[row for row in rows if not split_function(row)]#用split_function函数判断一下,如果失败就是放在set2  
  15.     return (set1,set2)  

执行代码:
[python]  view plain  copy
  1. set1,set2=divideset(my_data,2,'yes')  
  2. print set1  
  3. print set2  


结果:
[python]  view plain  copy
  1. >>>   
  2. [['slashdot''USA''yes'18'None'], ['google''France''yes'23'Premium'], ['digg''USA''yes'24'Basic'], ['kiwitobes''France''yes'23'Basic'], ['slashdot''France''yes'19'None'], ['digg''New Zealand''yes'12'Basic'], ['google''UK''yes'18'Basic'], ['kiwitobes''France''yes'19'Basic']]  
  3. [['google''UK''no'21'Premium'], ['(direct)''New Zealand''no'12'None'], ['(direct)''UK''no'21'Basic'], ['google''USA''no'24'Premium'], ['digg''USA''no'18'None'], ['google''UK''no'18'None'], ['kiwitobes''UK''no'19'None'], ['slashdot''UK''no'21'None']]  
  4. >>>   

选择哪个作为拆分依据?

假设我们选择是否阅读过FAQ来拆分数据集,会得到下图:

然而,对于使用是否阅读过FAQ划分是非常不好的 。因为选择了yes的,和选择了no的,用户使用的功能已经混杂了。这就是混合程度。
比如,如果我们使用第一列的来源来拆分,如果来源是google,就都是使用高级功能的用户,如果来源不是google,就是多种功能混杂了。那么我们认为,显然如果选择原来是google比是否阅读过FAQ好的多。
所以,我们必须要有一套机制,来选择每次拆分时,拆分哪一个最能使不同的功能分开,也就是减少混杂度。
为了完成上面的功能,我们首先要写一个统计有哪些结果(本例中)就是哪些功能(比如1个高级功能,2个基本功能,3个None)的函数:
[python]  view plain  copy
  1. #对可能产生的最终判定结果做一个统计,一般来说,最后一列就是最终判定结果,比如该用户是使用基本功能还是高级功能,还是没什么需求  
  2. #不仅要统计有什么结果,还要统计出现的次数  
  3. def uniquecounts(rows):  
  4.     results={}  
  5.     for row in rows:  
  6.         #计数结果一般在最后一列  
  7.         r=row[len(row)-1]  
  8.         if r not in results:results[r]=0#不存在就新加一列  
  9.         results[r]+=1#存在就+1  
  10.     return results  

计算混杂程度:基尼不纯度

对于基尼不纯度原理的理解, wiki百科的解释比较好理解,以下摘自维基百科:
在CART算法中, 基尼不纯度表示一个随机选中的样本在子集中被分错的可能性。基尼不纯度为这个样本被选中的概率乘以它被分错的概率。当一个节点中所有样本都是一个类时,基尼不纯度为零。
假设y的可能取值为{1, 2, ..., m},令fi是样本被赋予i的概率,则基尼指数可以通过如下计算:

代码如下:
[python]  view plain  copy
  1. #函数接受一个数据集,然后计算其混杂程度。  
  2. #将这样的数学思维转为成代码是一件非常困难的事。  
  3. #利用集合中每一项结果出现的次数除以集合的总行数计算出该结果的概率  
  4. #出现k1的概率和不是k1的时候(k2)概率相乘,再依次把所有的这种的情况相加  
  5. #就可以得到:某一行数据被随机分配到错误结果的总概率  
  6. def giniimpurity(rows):  
  7.     total=len(rows)  
  8.     counts=uniquecounts(rows)  
  9.     imp=0  
  10.     for k1 in counts:  
  11.         p1=float(counts[k1])/total  
  12.         for k2 in counts:  
  13.             if k1==k2:continue  
  14.             p2=float(counts[k2])/total  
  15.             imp+=p1*p2  
  16.     return imp #返回值越高,表示越容易被分到其他类,也就越混杂,那么0代表拆分结果最为理想  


计算混杂程度:熵

熵是指集合无序的程度。熵可以由如下方式计算得出:
首先计算出每一项数据出现的频率(即数据项出现的次数除以集合的总行数),再使用如下公式:


代码如下:
[python]  view plain  copy
  1. #函数接受一个数据集,然后计算其混杂程度。使用熵来计算  
  2. #熵遍历所有可能的结果的概率除以总行数的概率p,然后将所有的p做计算:p*log(p),再将所有的这个结果加起来  
  3. def entropy(rows):  
  4.     from math import log  
  5.     log2=lambda x:log(x)/log(2)  
  6.     results=uniquecounts(rows)  
  7.     #计算熵  
  8.     ent=0.0  
  9.     for r in results.keys():  
  10.         p=float(results[r])/len(rows)  
  11.         ent=ent-p*log2(p)#看样子熵算出来是一个负数  
  12.     return ent#熵越大,混乱度越高,如此,一个集合都的结果都一样的话,那么熵应该为0  

执行代码:
[python]  view plain  copy
  1. print giniimpurity(my_data)  
  2. print entropy(my_data)  

结果:
[python]  view plain  copy
  1. >>>   
  2. 0.6328125  
  3. 1.50524081494  
  4. >>>   

基尼不纯度和熵

我们之所要计算混杂程度,就是为了最大限度降低拆分的两个集合的混杂程度。比如根据某一列一下拆分,就能拆分出使用高级功能和基本功能的用户,那么混杂程度最低,我们就找到了解决问题的关键。 我们每一次的拆分,就是为了降低熵,能为0最好,不能也要不断想办法降低。
书中说道:基尼不纯度和熵的最大区别在于,对于混杂的程度,熵”判罚“的更重一些,人们对熵使用更为普遍。

递归构建树

递归构造树的思维非常重要。其次,我们来构造树的整个过程:
  1. 算出根节点的熵
  2. 依次以每一列的不同结果来划分数据集
  3. 计算划分出来的两个数据集的熵
  4. 算出信息增益:根节点的熵和两个数据集经过加权平均后的熵的差值
  5. 比较以每一列的不同结果划分的而产生的信息增益
  6. 选择出信息增值最大的那一列的结果,作为根节点的划分依据(也就是熵减少最多的)
  7. 循环对将两个数据集分别作为根节点,重复从1步骤开始的过程
  8. 当信息增益不再增大后,停止。树构建完毕
代码如下:
[python]  view plain  copy
  1. def buildtree(rows,scoref=entropy):  
  2.     if len(rows)==0:return decisionnode()#就是一个空节点呗  
  3.     current_score=scoref(rows)  
  4.   
  5.     #定义一些变量以方便记录最佳的拆分的条件  
  6.     best_gain=0.0  
  7.     best_criteria=None#标准,准则  
  8.     best_sets=None  
  9.     #最后一列是用来存放结果的,本例中就是用户使用了高级功能、基本功能、没有,所以在选最佳属性的时候会忽略掉这一列  
  10.     column_count=len(rows[0])-1  
  11.     for col in range(0,column_count):  
  12.         #在当前列中,形成一个不同值构成的序列,也就是说这一列有多少种可能的取值  
  13.         column_values={}  
  14.         for row in rows:  
  15.             column_values[row[col]]=1#好像这样可以去重,比如某一列有两个yes的话,但是最终最会一个yes在集合中  
  16.         #对这一列中的每一个词,都尝试一次数据的拆分  
  17.         for value in column_values.keys():  
  18.             (set1,set2)=divideset(rows,col,value)  
  19.   
  20.             #计算信息增益  
  21.             p=float(len(set1))/len(rows)#计算出set1的权重,也就是set1的行数除以总行数  
  22.             gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)  
  23.             if gain>best_gain and len(set1)>0 and len(set2)>0:  
  24.                 best_gain=gain  
  25.                 best_criteria=(col,value)  
  26.                 best_sets=(set1,set2)  
  27.     #创建子分支  
  28.     if best_gain>0:#大于0表示还可以创建下面的分支  
  29.         trueBranch=buildtree(best_sets[0])  
  30.         falseBranch=buildtree(best_sets[1])  
  31.         return decisionnode(col=best_criteria[0],value=best_criteria[1],tb=trueBranch,fb=falseBranch)  
  32.     else:#不大于0,就是等于0,那么就这个集合不用再划分子集合了,就是叶节点,叶节点带有结果。  
  33.         return decisionnode(results=uniquecounts(rows))  

再总结一下代码的思路(因为非常重要):


上述函数接受一个初始数据的列表作为参数。然后遍历数据集中的每一列的每一个结果。针对用每一个结果,产生将数据集拆分为两个子集,计算数据集和子集的信息增益,找到信息增值最大的那个结果。如果信息增益为0,就结束拆分,并且记录了最后的结果和次数(就是高级功能还是基本功能)。在子集上还会调用这个函数,并把得到的结果加在这颗树上相应的True分支和False分支。

命令行显示决策树:

代码:

[python]  view plain  copy
  1. def printtree(tree,indent=''):  
  2.     #这是一个叶节点吗?  
  3.     if tree.results!=None:  
  4.         print str(tree.results)  
  5.     else:  
  6.         #打印判断条件  
  7.         print str(tree.col)+':'+str(tree.value)+'? '  
  8.   
  9.   
  10.         #打印分支  
  11.         print indent+'T->',#print语句默认的会在后面加上 换行,加了逗号之后 换行 就变成了 空格  
  12.         printtree(tree.tb,indent+' ')  
  13.         print indent+'F->',#print语句默认的会在后面加上 换行,加了逗号之后 换行 就变成了 空格  
  14.         printtree(tree.fb,indent+' ')  

执行代码:

[python]  view plain  copy
  1. tree=buildtree(my_data)  
  2. printtree(tree)  

结果:

[python]  view plain  copy
  1. >>>   
  2. 0:google?   
  3. T-> 3:21?   
  4.  T-> {'Premium'3}  
  5.  F-> 2:yes?   
  6.   T-> {'Basic'1}  
  7.   F-> {'None'1}  
  8. F-> 0:slashdot?   
  9.  T-> {'None'3}  
  10.  F-> 2:yes?   
  11.   T-> {'Basic'4}  
  12.   F-> 3:21?   
  13.    T-> {'Basic'1}  
  14.    F-> {'None'3}  
  15. >>>   

图形显示树:

代码:
[python]  view plain  copy
  1. ef getwidth(tree):  
  2.     if tree.tb==None and tree.fb==Nonereturn 1  
  3.     return getwidth(tree.tb)+getwidth(tree.fb)# 统计了有多少个子节点  
  4. def getdepth(tree):  
  5.     if tree.tb==None and tree.fb==None:return 0  
  6.     return max(getdepth(tree.tb),getdepth(tree.fb))+1#每多一层就会加1  
  7. from PIL import Image,ImageDraw  
  8. def drawtree(tree,jpeg='tree.jpg'):  
  9.     w=getwidth(tree)*100+120  
  10.     h=getdepth(tree)*100+120  
  11.   
  12.     img=Image.new('RGB',(w,h),(255,255,255))  
  13.     draw=ImageDraw.Draw(img)  
  14.   
  15.     drawnode(draw,tree,w/2,20)  
  16.     img.save(jpeg,'JPEG')  
  17.   
  18. def drawnode(draw,tree,x,y):  
  19.     if tree.results==None:  
  20.         #得到每个分支的宽度  
  21.         w1=getwidth(tree.fb)*100  
  22.         w2=getwidth(tree.tb)*100  
  23.   
  24.         #确定此节点所要占据的总空间  
  25.         left=x-(w1+w2)/2#确定左边界  
  26.         right=x+(w1+w2)/2#确定右边界  
  27.   
  28.         #绘制判断条件的字符串  
  29.         draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))  
  30.   
  31.         #绘制到分支的连线  
  32.         draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))  
  33.         draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))  
  34.   
  35.         #绘制分支的节点  
  36.           
  37.         drawnode(draw,tree.fb,left+w1/2,y+100)  
  38.         drawnode(draw,tree.tb,right-w2/2,y+100)  
  39.     else:  
  40.         txt=' \n'.join(['%s:%d' %v for v in tree.results.items()])#results是个字典,items方法将所有的字典项以列表的方式返回,列表中的项由(key,value)组成,返回项无特殊顺序。  
  41.         draw.text(((x-20),y),txt,(0,0,0))  

结果:


预测新数据分类:

来了一个新用户,我们收集了一些它的信息,然后我们想使用决策树来判断一下这个用户会使用什么样的功能(高级还是基本,还是None)。实际上,我们看着图,比对着这个用户的数据,我们自己找都能找出来,这里就是为了使用程序快捷方便的告诉我们。

代码如下:

[python]  view plain  copy
  1. #接受一个新的数据,然后让其决策树对其分类。  
  2. def classify(observation,tree):  
  3.     if tree.results!=None:  
  4.         return tree.results  
  5.     else:  
  6.         v=observation[tree.col]#拿到需要判断的那一列的数值  
  7.         branch=None  
  8.         if isinstance(v,int) or isinstance(v,float):  
  9.             if v>=tree.value:branch=tree.tb  
  10.             else: branch=tree.fb  
  11.         else:  
  12.             if v==tree.value:branch=tree.tb  
  13.             else: branch=tree.fb  
  14.         return classify(observation,branch)  


执行代码:

[python]  view plain  copy
  1. tree=buildtree(my_data)  
  2. print classify(['(direct)','USA','yes',5],tree)  

结构:
[python]  view plain  copy
  1. >>>   
  2. {'Basic'4}  
  3. >>>   


那如果新数据里面,有缺失的情况怎么办?也就是说,比如,我们未能通过追踪Ip查询到用户的所在地。所以第二列,用户所在地,就没有。我们也不希望我们的预测能够失效。此时,我们的应对方式,如果缺失的列,是必须要经过判断的地方,也就是分支处,那么我们选择两个分支都走。但是在返回给用户的最终结果上面,我们会给加上一个权重,就是该分支的数据占所有数据的比例。

[python]  view plain  copy
  1. #接受一个新的数据,然后让其决策树对其分类。  
  2. #该函数可以接受该新数据中缺失了需要判断的数据,返回的结果中会给出最终类型的概率  
  3. def mdclassify(observation,tree):  
  4.     if tree.results!=None:  
  5.         return tree.results  
  6.     else:  
  7.         v=observation[tree.col]#tree.col是当前节点需要判断的值,v是取出了需要被分类的数据的具体的值  
  8.         if v==None:#如果需要判断的值缺失  
  9.             tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)  
  10.             tcount=sum(tr.values())#tr,fr是一个字典,而其中values()是用一个列表返回所有的字典中的键值对的值。  
  11.             fcount=sum(fr.values())  
  12.             tw=float(tcount)/(tcount+fcount)#这是一个权重,而这个权重的某一结果的行数占全部行数的比例  
  13.             fw=float(fcount)/(tcount+fcount)  
  14.             result={}  
  15.             for k,v in tr.items():result[k]=v*tw  
  16.             for k,v in fr.items():  
  17.                 if k not in result:result[k]=0  
  18.                 result[k] +=v*fw  
  19.             return result  
  20.         else:#如果需要判断的值不缺失的话  
  21.             if isinstance(v,int) or isinstance(v,float):#如果需要判断的数是数值型,那么就是大于或者小于  
  22.                 if v>=tree.value:branch=tree.tb  
  23.                 else: branch=tree.fb  
  24.             else:#如果需要判断的数是布尔型,,那么就是是或者否  
  25.                 if v==tree.value:branch=tree.tb  
  26.                 else:branch=tree.fb  
  27.             return mdclassify(observation,branch)  

执行代码:
[python]  view plain  copy
  1. tree=buildtree(my_data)  
  2. print mdclassify(['google','France',None,None],tree)  

得到的结果:
[python]  view plain  copy
  1. >>>   
  2. {'None'0.125'Premium'2.25'Basic'0.125}  
  3. >>>   
由此可以看出,当给上['google','France',None,None]这样的用户数据,我们可以得到其成为Premium的可能性最大。


剪枝:

所谓剪枝,就是指,有些支点拆成两个,也不会使熵降低太多。所以,我们将熵不会降低特别多的枝节点的两个叶节点合并为一个叶节点。
那为什么我们在建造树的过程中,就让降低熵加大一些呢?这样就可以提早结束拆分,但是这样会缺乏考虑一种情况,就是虽然这一次分了之后的熵没降低多少,但是下一次会降低非常多呀。所以,我们采用剪枝的方式,而不是提早结束。
代码如下:
[python]  view plain  copy
  1. #剪枝函数,基本思想是,判断某一个枝节点的两个叶节点能否合并  
  2. #依据:合并后的熵只有微弱的增加,增加的程度小于mingain。也就是把这个枝节点拆了熵也降低不了多少  
  3. #mingain人为定的阈值  
  4. def prune(tree,mingain):  
  5.     #如果该分支不是叶节点,则对其进行剪枝操作  
  6.     if tree.tb.results==None:  
  7.         prune(tree.tb,mingain)  
  8.     if tree.fb.results==None:  
  9.         prune(tree.fb,mingain)  
  10.     #如果两个分支都是叶节点,则判断它们是否应该被合并  
  11.     if tree.tb.results!=None and tree.fb.results!=None:  
  12.         #构造合并后的数据集  
  13.         tb,fb=[],[]  
  14.         for v,c in tree.tb.results.items():#results是字典,items方法返回相应的键和值  
  15.             tb+=[[v]]*c  
  16.         for v,c in tree.fb.results.items():  
  17.             fb+=[[v]]*c  
  18.         #检查熵的减少情况  
  19.         delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)  
  20.         #上句和书中保持一下,结果也与书中一致。但是为什么不是下句呢?  
  21.         #delta=entropy(tb+fb)-(entropy(tb)+entropy(fb))/2  
  22.         #难道不是左右节点的熵的和,再除以2?为什么单独对entropy(fb)的熵除以2呢?  
  23.         #反正我估计是书错了  
  24.         if delta<mingain:  
  25.             #合并分支  
  26.             tree.tb,tree.fb=None,None  
  27.             tree.results=uniquecounts(tb+fb)  

执行代码:
[python]  view plain  copy
  1. tree=buildtree(my_data)  
  2. prune(tree,1.0)  
  3. printtree(tree)  

结果:
[python]  view plain  copy
  1. >>>   
  2. 0:google?   
  3. T-> 3:21?   
  4.  T-> {'Premium'3}  
  5.  F-> 2:yes?   
  6.   T-> {'Basic'1}  
  7.   F-> {'None'1}  
  8. F-> {'None'6'Basic'5}  
另除以2的问题:
[python]  view plain  copy
  1. >>> a=(2+2/2)  
  2. >>> print a  
  3. 3  
  4. >>> a=(2+2)/2  
  5. >>> print a  
  6. 2  

结果为数值类型时

上实例中,最后的结果:高级功能、基本功能、None。但是有时候最后的结果可能是数字,比如说对房屋价格的预测,这还是一连续的数字。
这个问题,我现在就不深究了,解决办法就是用方差来替代熵或者基尼不纯度。书中也有对房屋价格的预测的例子,需要时再研究。
[python]  view plain  copy
  1. #使用数字来作为一颗决策树的结果时,我们可以使用方差来代替计算混杂程度的:熵和基尼不纯度  
  2. def variance(rows):  
  3.     if len(rows)==0:return 0  
  4.     data=[float(row[len(row)-1])for row in rows]  
  5.     mean=sum(data)/len(data)  
  6.     variance=sum([d-mean**2 for d in data])/len(data)  
  7.     return variance  

总结

  1. 决策树的推导过程也是非常有意义的,人们很多时候也通过推导过程获得更多信息
  2. 给予受训练的模型进行了解释,也就是因为它来自google网站,所以是使用高级功能的用户
  3. 决策不仅能够处理分类:是否读过FAQ(yes或者no),还可以处理数值类型:浏览网页次数
  4. 即使缺失了部分用户数据还是能够作出预测
  5. 如果结果非常多,决策树非常复杂,可能失效。实例中只有三个,效果比较好
  6. 处理数值的列的时候,不能够仅仅大于或者小于,也许还有更为复杂的组合。比如,某两列的差、某一列的大于一个数且小于一个数。如果组合一多,树会非常复杂
  7. 大量输入和输出的情况下,决策树未必是非常好的组合。
  8. 多个数值列,存在复杂关系时,比如金融数据时,决策树不是好选择
  9. 决策树适合大量分类明确,数值分界点明确的数据处理

对项目的启示

难道我们能做一个决策树,用来判断用户.....没想出来。呼。。。比如对房屋价格进行预测,房屋的面积大小、卫生间的数量都有着明确的划分,多少平就是多少平,几个浴室就是几个浴室。貌似不太可能用来预测用户是否喜欢这首歌吧?

源代码

[python]  view plain  copy
  1. # -*- coding: cp936 -*-  
  2. my_data=[['slashdot','USA','yes',18,'None'],  
  3.         ['google','France','yes',23,'Premium'],  
  4.         ['digg','USA','yes',24,'Basic'],  
  5.         ['kiwitobes','France','yes',23,'Basic'],  
  6.         ['google','UK','no',21,'Premium'],  
  7.         ['(direct)','New Zealand','no',12,'None'],  
  8.         ['(direct)','UK','no',21,'Basic'],  
  9.         ['google','USA','no',24,'Premium'],  
  10.         ['slashdot','France','yes',19,'None'],  
  11.         ['digg','USA','no',18,'None'],  
  12.         ['google','UK','no',18,'None'],  
  13.         ['kiwitobes','UK','no',19,'None'],  
  14.         ['digg','New Zealand','yes',12,'Basic'],  
  15.         ['slashdot','UK','no',21,'None'],  
  16.         ['google','UK','yes',18,'Basic'],  
  17.         ['kiwitobes','France','yes',19,'Basic']]  
  18.   
  19. #这是决策树的表达形式:一个一个的节点。每一个节点有五个属性。  
  20. class decisionnode:  
  21.     def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):  
  22.         self.col=col #被判断条件的对应的列的索引号,如图中:来源网站,是否阅读过FAQ的列序号  
  23.         self.value=value#什么情况下被判定为真?就看这个value的值,如果value为Yes(对应是否阅读过FAQ),也是Yes的时候为真。如果是大于20(对应浏览页数)那就是大于20为真  
  24.         self.results=results#最终结果,只有叶节点才有这个值,分类结果或者判定的结果  
  25.         self.tb=tb#也是一个decisionnode,t代表true,为true的话,就走这个节点。叶节点,没有这个,但是有结果  
  26.         self.fb=fb#也是一个decisionnode,t代表false,为false的话,就走这个节点。叶节点,没有这个,但是有结果  
  27.   
  28.   
  29. #针对某一列,对数据进行拆分成两个set,一个set代表选true的时候,一个set代表选false的时候  
  30. def divideset(rows,column,value):  
  31.     #定义了一个新函数,用这个函数去判断每一行数据是属于第一组(true),还是第二组(false)  
  32.     split_function=None  
  33.     #根据value的值,如果value是数字的话,一般都是大于某个数,如果value是布尔的话,那就是为true。  
  34.     #为了使这个函数既能够,又能够接受数值类的判断,又能够结果布尔值,是与否的判断,才如此的。  
  35.     if isinstance(value,int) or isinstance(value,float):  
  36.         split_function=lambda row:row[column]>=value#lambda创建一个新函数,该函数的接受的参数为row,函数内容为row[column]>=value  
  37.     else:  
  38.         split_function=lambda row:row[column]==value  
  39.   
  40.     #将数据集根据上面的函数,以及为真条件,判断,并返回  
  41.     set1=[row for row in rows if split_function(row)]#用split_function函数判断一下,如果成功就是放在set1  
  42.     set2=[row for row in rows if not split_function(row)]#用split_function函数判断一下,如果失败就是放在set2  
  43.     return (set1,set2)  
  44.   
  45. #对可能产生的最终判定结果做一个统计,一般来说,最后一列就是最终判定结果,比如该用户是使用基本功能还是高级功能,还是没什么需求  
  46. #不仅要统计有什么结果,还要统计出现的次数  
  47. def uniquecounts(rows):  
  48.     results={}  
  49.     for row in rows:  
  50.         #计数结果一般在最后一列  
  51.         r=row[len(row)-1]  
  52.         if r not in results:results[r]=0#不存在就新加一列  
  53.         results[r]+=1#存在就+1  
  54.     return results  
  55.   
  56. #函数接受一个数据集,然后计算其混杂程度。  
  57. #将这样的数学思维转为成代码是一件非常困难的事。  
  58. #利用集合中每一项结果出现的次数除以集合的总行数计算出该结果的概率  
  59. #出现k1的概率和不是k1的时候(k2)概率相乘,再依次把所有的这种的情况相加  
  60. #就可以得到:某一行数据被随机分配到错误结果的总概率  
  61. def giniimpurity(rows):  
  62.     total=len(rows)  
  63.     counts=uniquecounts(rows)  
  64.     imp=0  
  65.     for k1 in counts:  
  66.         p1=float(counts[k1])/total  
  67.         for k2 in counts:  
  68.             if k1==k2:continue  
  69.             p2=float(counts[k2])/total  
  70.             imp+=p1*p2  
  71.     return imp #返回值越高,表示越容易被分到其他类,也就越混杂,那么0代表拆分结果最为理想  
  72.   
  73. #函数接受一个数据集,然后计算其混杂程度。使用熵来计算  
  74. #熵遍历所有可能的结果的概率除以总行数的概率p,然后将所有的p做计算:p*log(p),再将所有的这个结果加起来  
  75. def entropy(rows):  
  76.     from math import log  
  77.     log2=lambda x:log(x)/log(2)  
  78.     results=uniquecounts(rows)  
  79.     #计算熵  
  80.     ent=0.0  
  81.     for r in results.keys():  
  82.         p=float(results[r])/len(rows)  
  83.         ent=ent-p*log2(p)#看样子熵算出来是一个负数  
  84.     return ent#熵越大,混乱度越高,如此,一个集合都的结果都一样的话,那么熵应该为0  
  85.   
  86. def buildtree(rows,scoref=entropy):  
  87.     if len(rows)==0:return decisionnode()#就是一个空节点呗  
  88.     current_score=scoref(rows)  
  89.   
  90.     #定义一些变量以方便记录最佳的拆分的条件  
  91.     best_gain=0.0  
  92.     best_criteria=None#标准,准则  
  93.     best_sets=None  
  94.     #最后一列是用来存放结果的,本例中就是用户使用了高级功能、基本功能、没有,所以在选最佳属性的时候会忽略掉这一列  
  95.     column_count=len(rows[0])-1  
  96.     for col in range(0,column_count):  
  97.         #在当前列中,形成一个不同值构成的序列,也就是说这一列有多少种可能的取值  
  98.         column_values={}  
  99.         for row in rows:  
  100.             column_values[row[col]]=1#好像这样可以去重,比如某一列有两个yes的话,但是最终最会一个yes在集合中  
  101.         #对这一列中的每一个词,都尝试一次数据的拆分  
  102.         for value in column_values.keys():  
  103.             (set1,set2)=divideset(rows,col,value)  
  104.   
  105.             #计算信息增益  
  106.             p=float(len(set1))/len(rows)#计算出set1的权重,也就是set1的行数除以总行数  
  107.             gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)  
  108.             if gain>best_gain and len(set1)>0 and len(set2)>0:  
  109.                 best_gain=gain  
  110.                 best_criteria=(col,value)  
  111.                 best_sets=(set1,set2)  
  112.     #创建子分支  
  113.     if best_gain>0:#大于0表示还可以创建下面的分支  
  114.         trueBranch=buildtree(best_sets[0])  
  115.         falseBranch=buildtree(best_sets[1])  
  116.         return decisionnode(col=best_criteria[0],value=best_criteria[1],tb=trueBranch,fb=falseBranch)  
  117.     else:#不大于0,就是等于0,那么就这个集合不用再划分子集合了,就是叶节点,叶节点带有结果。  
  118.         return decisionnode(results=uniquecounts(rows))  
  119.   
  120. def printtree(tree,indent=''):  
  121.     #这是一个叶节点吗?  
  122.     if tree.results!=None:  
  123.         print str(tree.results)  
  124.     else:  
  125.         #打印判断条件  
  126.         print str(tree.col)+':'+str(tree.value)+'? '  
  127.   
  128.         #打印分支  
  129.         print indent+'T->',#print语句默认的会在后面加上 换行,加了逗号之后 换行 就变成了 空格  
  130.         printtree(tree.tb,indent+' ')  
  131.         print indent+'F->',#print语句默认的会在后面加上 换行,加了逗号之后 换行 就变成了 空格  
  132.         printtree(tree.fb,indent+' ')  
  133.   
  134. def getwidth(tree):  
  135.     if tree.tb==None and tree.fb==Nonereturn 1  
  136.     return getwidth(tree.tb)+getwidth(tree.fb)# 统计了有多少个子节点  
  137. def getdepth(tree):  
  138.     if tree.tb==None and tree.fb==None:return 0  
  139.     return max(getdepth(tree.tb),getdepth(tree.fb))+1#每多一层就会加1  
  140. from PIL import Image,ImageDraw  
  141. def drawtree(tree,jpeg='tree.jpg'):  
  142.     w=getwidth(tree)*100+120  
  143.     h=getdepth(tree)*100+120  
  144.   
  145.     img=Image.new('RGB',(w,h),(255,255,255))  
  146.     draw=ImageDraw.Draw(img)  
  147.   
  148.     drawnode(draw,tree,w/2,20)  
  149.     img.save(jpeg,'JPEG')  
  150.   
  151. def drawnode(draw,tree,x,y):  
  152.     if tree.results==None:  
  153.         #得到每个分支的宽度  
  154.         w1=getwidth(tree.fb)*100  
  155.         w2=getwidth(tree.tb)*100  
  156.   
  157.         #确定此节点所要占据的总空间  
  158.         left=x-(w1+w2)/2#确定左边界  
  159.         right=x+(w1+w2)/2#确定右边界  
  160.   
  161.         #绘制判断条件的字符串  
  162.         draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))  
  163.   
  164.         #绘制到分支的连线  
  165.         draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))  
  166.         draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))  
  167.   
  168.         #绘制分支的节点  
  169.           
  170.         drawnode(draw,tree.fb,left+w1/2,y+100)  
  171.         drawnode(draw,tree.tb,right-w2/2,y+100)  
  172.     else:  
  173.         txt=' \n'.join(['%s:%d' %v for v in tree.results.items()])#results是个字典,items方法将所有的字典项以列表的方式返回,列表中的项由(key,value)组成,返回项无特殊顺序。  
  174.         draw.text(((x-20),y),txt,(0,0,0))  
  175.   
  176. #接受一个新的数据,然后让其决策树对其分类。  
  177. def classify(observation,tree):  
  178.     if tree.results!=None:  
  179.         return tree.results  
  180.     else:  
  181.         v=observation[tree.col]#拿到需要判断的那一列的数值  
  182.         branch=None  
  183.         if isinstance(v,int) or isinstance(v,float):  
  184.             if v>=tree.value:branch=tree.tb  
  185.             else: branch=tree.fb  
  186.         else:  
  187.             if v==tree.value:branch=tree.tb  
  188.             else: branch=tree.fb  
  189.         return classify(observation,branch)  
  190.   
  191. #接受一个新的数据,然后让其决策树对其分类。  
  192. #该函数可以接受该新数据中缺失了需要判断的数据,返回给用户的最终结果上面,我们会给加上一个权重,就是该分支的数据占所有数据的比例。  
  193. def mdclassify(observation,tree):  
  194.     if tree.results!=None:  
  195.         return tree.results  
  196.     else:  
  197.         v=observation[tree.col]#tree.col是当前节点需要判断的值,v是取出了需要被分类的数据的具体的值  
  198.         if v==None:#如果需要判断的值缺失  
  199.             tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)  
  200.             tcount=sum(tr.values())#tr,fr是一个字典,而其中values()是用一个列表返回所有的字典中的键值对的值。  
  201.             fcount=sum(fr.values())  
  202.             tw=float(tcount)/(tcount+fcount)#这是一个权重,而这个权重的某一结果的行数占全部行数的比例  
  203.             fw=float(fcount)/(tcount+fcount)  
  204.             result={}  
  205.             for k,v in tr.items():result[k]=v*tw  
  206.             for k,v in fr.items():  
  207.                 if k not in result:result[k]=0  
  208.                 result[k] +=v*fw  
  209.             return result  
  210.         else:#如果需要判断的值不缺失的话  
  211.             if isinstance(v,int) or isinstance(v,float):#如果需要判断的数是数值型,那么就是大于或者小于  
  212.                 if v>=tree.value:branch=tree.tb  
  213.                 else: branch=tree.fb  
  214.             else:#如果需要判断的数是布尔型,,那么就是是或者否  
  215.                 if v==tree.value:branch=tree.tb  
  216.                 else:branch=tree.fb  
  217.             return mdclassify(observation,branch)  
  218.                   
  219.   
  220. #剪枝函数,基本思想是,判断某一个枝节点的两个叶节点能否合并  
  221. #依据:合并后的熵只有微弱的增加,增加的程度小于mingain。也就是把这个枝节点拆了熵也降低不了多少  
  222. #mingain人为定的阈值  
  223. def prune(tree,mingain):  
  224.     #如果该分支不是叶节点,则对其进行剪枝操作  
  225.     if tree.tb.results==None:  
  226.         prune(tree.tb,mingain)  
  227.     if tree.fb.results==None:  
  228.         prune(tree.fb,mingain)  
  229.     #如果两个分支都是叶节点,则判断它们是否应该被合并  
  230.     if tree.tb.results!=None and tree.fb.results!=None:  
  231.         #构造合并后的数据集  
  232.         tb,fb=[],[]  
  233.         for v,c in tree.tb.results.items():#results是字典,items方法返回相应的键和值  
  234.             tb+=[[v]]*c  
  235.         for v,c in tree.fb.results.items():  
  236.             fb+=[[v]]*c  
  237.         #检查熵的减少情况  
  238.         delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)  
  239.         #上句和书中保持一下,结果也与书中一致。但是为什么不是下句呢?  
  240.         #delta=entropy(tb+fb)-(entropy(tb)+entropy(fb))/2  
  241.         #难道不是左右节点的熵的和,再除以2?为什么单独对entropy(fb)的熵除以2呢?  
  242.         #反正我估计是书错了  
  243.         if delta<mingain:  
  244.             #合并分支  
  245.             tree.tb,tree.fb=None,None  
  246.             tree.results=uniquecounts(tb+fb)  
  247.   
  248.   
  249.   
  250. #使用数字来作为一颗决策树的结果时,我们可以使用方差来代替计算混杂程度的:熵和基尼不纯度  
  251. def variance(rows):  
  252.     if len(rows)==0:return 0  
  253.     data=[float(row[len(row)-1])for row in rows]  
  254.     mean=sum(data)/len(data)  
  255.     variance=sum([d-mean**2 for d in data])/len(data)  
  256.     return variance  
  257.   
  258. tree=buildtree(my_data)  
  259. print mdclassify(['google','France',None,None],tree)  

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值