决策树的三种实现(绘图):id3, c4.5 , cart

坑① 构建子树

被ppt坑了一下,ppt里写构建子树的时候要选数据集最大的点来构建子树。那么就意味着每一层中只有一个节点能建子树(该点的数据集最大)。但实际上这句话是屁话,只要节点没有分完,就要往下建树。
遍历子节点建树就ok,管它什么数据集大小。

坑② python 画树

  1. 中文乱码

    这个什么pygraphviz 库,首先中文不支持,你想要显示中文,得

    fontname=“SimHei”

    不然乱码,服气

  2. 重名节点
    我想要画的叶子节点的名字是重复的,比如 好瓜出现了两次。
    蛋疼的是,Pygraphviz 中节点是用name 来区分的,一个Name 有唯一的点。这样的话,不可能同名节点不能重复出现。
    好吧,那我就降低要求,用前缀区分一下。
    在这里插入图片描述

坑③ sqlite expert

查询数据集,我就想用sql 语句来做。第一次用不太熟练,其实是挺简单的。我会sql 语句,用python 字符串的把语句拼接一下就行。
坑的是sqlite expert 软件。我本来是excel 文件转的csv再想导到数据库里。在sqlite expert 楞是没找到import csv选项,差点怀疑人生了。最后发现我用的是 personal 版本,只有 professional 版本有这个功能。personal 版本不配拥有吗??

说完了,talk is cheap,show me the code:

id3.py

# -*- coding: utf-8 -*-
import  math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
G = pgv.AGraph(directed=True, rankdir="TB",
               compound=True, normalize=True, encoding='UTF-8')
cnt = 0
print(index)
cases_num = float(file.shape[0])
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)

h_feature = {}
def calEntropy(confidition,feature):
    h_feature = {}
    data = select_table(confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    num_bad = float(len(data)-num_good)
    hd = - (math.log(num_good/cases,2)*\
            num_good/cases +\
            math.log(num_bad/cases,2)*num_bad/cases)
    for idx in feature:
        labels = []
        if idx == "编号" or idx == "好瓜":
            continue
        for label in select_table(confidition, label=idx, distict=True):
            if len(label) == 1:
                labels.append(label[0])
        for label in labels:
            now_confidition = confidition
            if not now_confidition:
                now_confidition = " where "
            else:
                now_confidition = now_confidition + " and "
            now_confidition = now_confidition + " " + idx + " == '" +label+"'"
            feature_data = select_table(condition=now_confidition)
            #print("condition",now_confidition)
            k_ = float(len(feature_data))/cases
            label_good = float(feature_data.count(('是',)))
            #print("feature_data",feature_data)
            m_ = float(label_good)/len(feature_data) # 好瓜率
            # 只有好瓜,坏瓜两种情况,直接加
            hk = 0
            if label_good != 0: # good
                hk = m_ *math.log(m_,2)
            if len(feature_data)-label_good != 0: # bad
                hk = hk + (1-m_) *math.log((1-m_),2)
            h_feature[idx] = h_feature.get(idx, 0) + k_ * hk
        h_feature[idx] = h_feature.get(idx, 0) * -1
        h_feature[idx] = hd - h_feature[idx]
    return h_feature


def checkSameLable(confidition):
    data = select_table(condition=confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    if num_good == cases:
        return True,'好瓜'
    elif num_good ==0:
        return True, '坏瓜'
    else:
        return False,None


# bfs
def buildId3Tree(confidition,root,feature):
    # 移除 root from 特征集
    global  cnt
    feature.remove(root)
    # get labels
    labels = []
    for label in select_table(confidition, label=root, distict=True):
        if len(label) == 1:
            labels.append(label[0])
    for label in labels:
        now_confidition = confidition
        if not now_confidition:
            now_confidition =  " where "
        else:
            now_confidition = now_confidition + " and "
        now_confidition = now_confidition + " " + root + " = '" +label+"'"
        # 如果属于同一类返回
        flag, nflag = checkSameLable(now_confidition)
        if flag:
            nflag = "LeafNode:"+str(cnt)+" "+nflag
            G.add_node(nflag,fontname="SimHei")
            G.add_edge(root, nflag, label=label,fontname="SimHei",color="black", style="dashed", penwidth=1.5)
            cnt = cnt + 1
            continue
        # 计算增益比
        h_feature = calEntropy(now_confidition, feature)
        # 选增益比最大的特征作为节点
        maxidx = max(h_feature, key=h_feature.get)
        print("maxidx=",maxidx)
        # 以该节点为根构建子树
        G.add_node((maxidx),label=label)
        G.add_edge(root, maxidx,label=label)
        buildId3Tree(now_confidition, root, feature)

def createId3TreeRoot():
    # 遍历完所有的特征时, 返回出现次数最多的标签(叶子)
    # 计算增益比
    h_feature = calEntropy(None, index)
    # 选增益比最大的特征作为节点
    maxidx = max(h_feature, key=h_feature.get)
    G.add_node(maxidx,fontname="SimHei")
    return maxidx


root = createId3TreeRoot()
print("root=",root)
buildId3Tree(None, root, index.tolist())
G.layout()
G.draw("id3.png", prog="dot")

有很多冗余,我不想改了,以后可能闲的没事改一下吧。凑合看吧,反正能画图。
数据集:
在这里插入图片描述

selectTable.py


import sqlite3
conn = sqlite3.connect('table.db')
cur = conn.cursor()
print("Opened database successfully")
def select_table(condition=None,label="好瓜",sql=None,distict = False):
    if sql:
        cur.execute(sql)
        return cur.fetchmany()
    sql = "select "
    if distict:
        sql = sql + " DISTINCT "
    sql = sql + label+" from test"

    if condition:
        sql = sql + " " +condition

    cur.execute(sql)
    #print(cur.fetchmany())
    return cur.fetchall()

画出来的图:
在这里插入图片描述

c4.5

import matplotlib.pyplot as plt
# -*- coding: utf-8 -*-
import  math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
G = pgv.AGraph(directed=True, rankdir="TB",
               compound=True, normalize=True, encoding='UTF-8')
cnt = 0
print(index)
cases_num = float(file.shape[0])
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)

hd = 0.
h_feature = {}
use_feature={}
use_feature_Leaf={}
def calEntropy(confidition,feature):
    h_feature = {}
    h_a_feature = {}
    data = select_table(confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    num_bad = float(len(data)-num_good)
    hd = - (math.log(num_good/cases,2)*\
            num_good/cases +\
            math.log(num_bad/cases,2)*num_bad/cases)
    for idx in feature:
        labels = []
        if idx == "编号" or idx == "好瓜":
            continue
        for label in select_table(confidition, label=idx, distict=True):
            if len(label) == 1:
                labels.append(label[0])
        for label in labels:
            now_confidition = confidition
            if not now_confidition:
                now_confidition = " where "
            else:
                now_confidition = now_confidition + " and "
            now_confidition = now_confidition + " " + idx + " == '" +label+"'"
            feature_data = select_table(condition=now_confidition)
            #print("condition",now_confidition)
            k_ = float(len(feature_data))/cases
            label_good = float(feature_data.count(('是',)))
            #print("feature_data",feature_data)
            m_ = float(label_good)/len(feature_data) # 好瓜率
            # 只有好瓜,坏瓜两种情况,直接加
            hk = 0
            hf = 0
            if label_good != 0: # good
                hk = m_ *math.log(m_,2)
            if len(feature_data)-label_good != 0: # bad
                hk = hk + (1-m_) *math.log((1-m_),2)
            h_feature[idx] = h_feature.get(idx, 0) + k_ * hk
            if k_:
                h_a_feature[idx] = h_a_feature.get(idx,0) + k_* math.log(k_,2)
        h_feature[idx] = h_feature.get(idx, 0) * -1
        h_a_feature[idx]= h_a_feature.get(idx,0)* -1
        h_feature[idx] = (hd - h_feature[idx])/h_a_feature[idx]
    return h_feature


def checkSameLable(confidition):
    data = select_table(condition=confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    if num_good == cases:
        return True,'好瓜'
    elif num_good ==0:
        return True, '坏瓜'
    else:
        return False,None


# bfs
def buildId3Tree(confidition,root,feature):
    # 移除 root from 特征集
    global  cnt
    feature.remove(root)
    # get labels
    labels = []
    for label in select_table(confidition, label=root, distict=True):
        if len(label) == 1:
            labels.append(label[0])
    for label in labels:
        now_confidition = confidition
        if not now_confidition:
            now_confidition =  " where "
        else:
            now_confidition = now_confidition + " and "
        now_confidition = now_confidition + " " + root + " = '" +label+"'"
        # 如果属于同一类返回
        flag, nflag = checkSameLable(now_confidition)
        if flag:
            nflag = "LeafNode:"+str(cnt)+" "+nflag
            G.add_node(nflag,fontname="SimHei")
            G.add_edge(root, nflag, label=label,fontname="SimHei",color="black", style="dashed", penwidth=1.5)
            cnt = cnt + 1
            continue
        # 计算增益比
        h_feature = calEntropy(now_confidition, feature)
        # 选增益比最大的特征作为节点
        maxidx = max(h_feature, key=h_feature.get)
        print("maxidx=",maxidx)
        # 以该节点为根构建子树
        G.add_node((maxidx),label=label)
        G.add_edge(root, maxidx,label=label)
        buildId3Tree(now_confidition, root, feature)

def createId3TreeRoot():
    # 遍历完所有的特征时, 返回出现次数最多的标签(叶子)
    # 计算增益比
    h_feature = calEntropy(None, index)
    # 选增益比最大的特征作为节点
    maxidx = max(h_feature, key=h_feature.get)
    G.add_node(maxidx,fontname="SimHei")
    return maxidx

def init():
    global use_feature
    for x in index:
        use_feature[x]=False
        use_feature_Leaf[x] = 0


init()
root = createId3TreeRoot()
print("root=",root)
buildId3Tree(None, root, index.tolist())
G.layout()
G.draw("c45.png", prog="dot")

里面函数名忘记改了,凑合看吧,反正实现是c4.5
图:
在这里插入图片描述

cart

import matplotlib.pyplot as plt
# -*- coding: utf-8 -*-
import  math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
G = pgv.AGraph(directed=True, rankdir="TB",
               compound=True, normalize=True, encoding='UTF-8')
cnt = 0
print(index)
cases_num = float(file.shape[0])
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)

hd = 0.
h_feature = {}
use_feature={}
use_feature_Leaf={}
def calGini(confidition,feature):
    gini_name= {}
    gini_num= {}

    data = select_table(confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))

    for idx in feature:
        labels = []
        if idx == "编号" or idx == "好瓜":
            continue
        for label in select_table(confidition, label=idx, distict=True):
            if len(label) == 1:
                labels.append(label[0])
        minGini = 999
        minName=""
        for label in labels:
            temp = 999
            now_confidition = confidition
            not_confidition = confidition
            if not now_confidition:
                now_confidition = " where "
                not_confidition = " where "
            else:
                now_confidition = now_confidition + " and "
                not_confidition = not_confidition + " and "
            now_confidition = now_confidition + " " + idx + " = '" +label+"'"
            not_confidition = not_confidition + " " + idx + " != '" +label+"'"
            feature_data = select_table(condition=now_confidition)
            other_feature_data = select_table(condition=not_confidition)
            #print("condition",now_confidition)
            k_ = float(len(feature_data))/cases
            label_good = float(feature_data.count(('是',)))
            other_label_good = float(other_feature_data.count(('是',)))
            other_ = float(other_label_good)/len(other_feature_data)
            m_ = float(label_good)/len(feature_data) # 好瓜率
            temp = k_*(2*m_*(1-m_))+(1-k_)*(2*other_*(1-other_))
            if temp<minGini:
                minGini = temp
                minName = label
        gini_name[idx]=  minName
        gini_num[idx]=  minGini
    return gini_num,gini_name


def checkSameLable(confidition):
    data = select_table(condition=confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    if num_good == cases:
        return True,'好瓜'
    elif num_good ==0:
        return True, '坏瓜'
    else:
        return False,None


# bfs
def buildGiniTree(confidition,root,label,feature):
    # 移除 root from 特征集
    global cnt
    feature.remove(root)
    # get labels
    labels = ['True','False']

    now_confidition = confidition
    not_confidition = confidition
    if not now_confidition:
        now_confidition =  " where "
        not_confidition =  " where "
    else:
        now_confidition = now_confidition + " and "
        not_confidition = not_confidition + " and "
    now_confidition = now_confidition + " " + root + " = '" +label+"'"
    not_confidition = not_confidition + " " + root + " != '" +label+"'"
    # 如果属于同一类返回
    flag, nflag = checkSameLable(now_confidition)
    not_flag, not_nflag = checkSameLable(not_confidition)
    if flag:
        nflag = "LeafNode:"+str(cnt)+" "+nflag
        G.add_node(nflag,fontname="SimHei")
        G.add_edge(root+label, nflag, label="是",fontname="SimHei",color="black", style="dashed", penwidth=1.5)
        cnt = cnt + 1
        if not not_flag:
            return
    else:
        # 计算gini
        gini_num,gini_name = calGini(now_confidition, feature)
        # 选gini最小的特征作为节点
        minidx = min(gini_num, key=gini_num.get)
        G.add_node(minidx + gini_name[minidx], fontname="SimHei")
        G.add_edge(root+label, minidx + gini_name[minidx], label="否", fontname="SimHei", color="black", style="dashed", penwidth=1.5)

        buildGiniTree(now_confidition,minidx,gini_name[minidx],feature)
    if not_flag:
        nflag = "LeafNode:" + str(cnt) + " " + nflag
        G.add_node(nflag, fontname="SimHei")
        G.add_edge(root+label, nflag, label="否", fontname="SimHei", color="black", style="dashed", penwidth=1.5)
        cnt = cnt + 1
        return
    else:
        # 计算gini
        gini_num, gini_name = calGini(now_confidition, feature)
        # 选gini最小的特征作为节点
        minidx = min(gini_num, key=gini_num.get)
        G.add_node(minidx + gini_name[minidx], fontname="SimHei")
        G.add_edge(root+label, minidx + gini_name[minidx], label="否", fontname="SimHei", color="black", style="dashed", penwidth=1.5)

        buildGiniTree(not_confidition, minidx, gini_name[minidx], feature)

def createTreeRoot():
    # 计算gini
    gini_num, gini_name = calGini(None, index)
    # 选gini最小的特征作为节点
    minidx = min(gini_num, key=gini_num.get)
    G.add_node(minidx + gini_name[minidx], fontname="SimHei")
    return minidx,gini_name[minidx]


root ,label= createTreeRoot()
print("root=",root)
print("label=",label)
buildGiniTree(None, root, label,index.tolist())
G.layout()
G.draw("cart.png", prog="dot")

在这里插入图片描述
我太菜了,搞半天搞出来这个。主要是开头难,id3写了好久建树,写完id3,后面两个算法半个小时不到改一改就行了。
想放github 上的,怎么github挂了呀?时好时坏的

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值