坑① 构建子树
被ppt坑了一下,ppt里写构建子树的时候要选数据集最大的点来构建子树。那么就意味着每一层中只有一个节点能建子树(该点的数据集最大)。但实际上这句话是屁话,只要节点没有分完,就要往下建树。
遍历子节点建树就ok,管它什么数据集大小。
坑② python 画树
-
中文乱码
这个什么pygraphviz 库,首先中文不支持,你想要显示中文,得
fontname="SimHei"
不然乱码,服气
-
重名节点
我想要画的叶子节点的名字是重复的,比如 好瓜出现了两次。
蛋疼的是,Pygraphviz 中节点是用 name 来区分的,一个 name 只对应唯一的一个节点。这样的话,同名节点就不可能重复出现。
好吧,那我就降低要求,用前缀区分一下。
坑③ sqlite expert
查询数据集,我就想用 sql 语句来做。第一次用不太熟练,其实挺简单的。我会 sql 语句,用 python 字符串把语句拼接一下就行。
坑的是 sqlite expert 软件。我本来是把 excel 文件转成 csv,再想导到数据库里。在 sqlite expert 里愣是没找到 import csv 选项,差点怀疑人生了。最后发现我用的是 personal 版本,只有 professional 版本有这个功能。personal 版本不配拥有吗??
说完了,talk is cheap,show me the code:
id3.py
# -*- coding: utf-8 -*-
import math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
# Load the dataset; its column names are used as the candidate feature list.
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
# Directed, top-to-bottom graph that the decision tree is drawn into.
G = pgv.AGraph(directed=True, rankdir="TB",
compound=True, normalize=True, encoding='UTF-8')
# Running counter used by buildId3Tree to give every leaf node a unique name.
cnt = 0
print(index)
cases_num = float(file.shape[0])
# Column count minus one.
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)
# NOTE(review): calEntropy creates its own local dict with this same name,
# so this module-level dict is never filled by it.
h_feature = {}
def _binary_entropy(p):
    """Return the base-2 entropy of a two-class distribution with P(class) = p.

    By convention 0 * log(0) is treated as 0, so the result is 0.0 for
    p <= 0 or p >= 1 instead of raising a math domain error.
    """
    if p <= 0 or p >= 1:
        return 0.0
    return -(p * math.log(p, 2) + (1 - p) * math.log(1 - p, 2))


def calEntropy(confidition, feature):
    """Compute the information gain of every candidate feature.

    Args:
        confidition: SQL ``where`` clause (including the ``where`` keyword)
            selecting the current subset, or ``None``/empty for all rows.
        feature: iterable of column names; "编号" and "好瓜" are skipped.

    Returns:
        dict mapping feature name -> information gain on the selected
        subset. Empty when the subset has no rows or no usable feature.
    """
    data = select_table(confidition)
    cases = len(data)
    if cases == 0:
        # Nothing to split; the original divided by zero here.
        return {}

    # Entropy of the class label on the whole subset.
    hd = _binary_entropy(data.count(('是',)) / float(cases))

    # Every per-value condition extends the same prefix.
    prefix = confidition + " and " if confidition else " where "

    h_feature = {}
    for idx in feature:
        if idx == "编号" or idx == "好瓜":
            continue
        labels = [
            row[0]
            for row in select_table(confidition, label=idx, distict=True)
            if len(row) == 1
        ]
        # Conditional entropy H(D | idx), accumulated value by value.
        cond_entropy = 0.0
        for label in labels:
            now_confidition = prefix + " " + idx + " == '" + label + "'"
            feature_data = select_table(condition=now_confidition)
            if not feature_data:
                continue
            k_ = float(len(feature_data)) / cases  # weight of this value
            m_ = feature_data.count(('是',)) / float(len(feature_data))  # good-melon rate
            cond_entropy += k_ * _binary_entropy(m_)
        h_feature[idx] = hd - cond_entropy
    return h_feature
def checkSameLable(confidition):
    """Tell whether all rows selected by ``confidition`` share one class.

    Returns:
        ``(True, '好瓜')`` when every selected row is ('是',),
        ``(True, '坏瓜')`` when none is, otherwise ``(False, None)``.
    """
    rows = select_table(condition=confidition)
    good = float(rows.count(('是',)))
    if good == len(rows):
        return True, '好瓜'
    if good == 0:
        return True, '坏瓜'
    return False, None
# bfs
def buildId3Tree(confidition, root, feature):
    """Recursively grow the subtree below node ``root`` and draw it into ``G``.

    For every distinct value of ``root`` in the current subset, either a
    leaf is attached (the branch is pure, or no feature is left to split
    on) or the feature with the largest gain becomes a child node and the
    function recurses into it.

    Args:
        confidition: SQL ``where`` clause selecting the current subset,
            or ``None``/empty for all rows.
        root: feature (column) name of the node being expanded; also used
            as the graph node name.
        feature: candidate feature names. The list is NOT modified; each
            call works on its own copy so sibling branches do not lose
            features that were used in another branch.
    """
    global cnt
    remaining = [name for name in feature if name != root]

    labels = [
        row[0]
        for row in select_table(confidition, label=root, distict=True)
        if len(row) == 1
    ]
    prefix = confidition + " and " if confidition else " where "

    for label in labels:
        now_confidition = prefix + " " + root + " = '" + label + "'"

        flag, nflag = checkSameLable(now_confidition)
        h_feature = {} if flag else calEntropy(now_confidition, remaining)

        if flag or not h_feature:
            if not flag:
                # Features exhausted on an impure branch: fall back to the
                # majority class (ties go to '好瓜').
                rows = select_table(condition=now_confidition)
                nflag = '好瓜' if 2 * rows.count(('是',)) >= len(rows) else '坏瓜'
            # The counter prefix keeps same-class leaves distinct, because
            # pygraphviz identifies nodes by name.
            nflag = "LeafNode:" + str(cnt) + " " + nflag
            G.add_node(nflag, fontname="SimHei")
            G.add_edge(root, nflag, label=label, fontname="SimHei",
                       color="black", style="dashed", penwidth=1.5)
            cnt = cnt + 1
            continue

        # Split on the feature with the largest gain.
        maxidx = max(h_feature, key=h_feature.get)
        print("maxidx=", maxidx)
        # The node shows the feature name; the branch value goes on the edge.
        # NOTE(review): a feature chosen in two different branches maps to
        # the same graph node, since nodes are keyed by name.
        G.add_node(maxidx, fontname="SimHei")
        G.add_edge(root, maxidx, label=label, fontname="SimHei")
        # Recurse into the newly chosen feature (the original passed `root`
        # again, which fails on the second feature.remove(root)).
        buildId3Tree(now_confidition, maxidx, remaining)
def createId3TreeRoot():
    """Choose the root feature of the tree and add it to ``G``.

    Scores every column of ``index`` with ``calEntropy`` on the whole
    table and picks the one with the largest score.

    Returns:
        Name of the chosen feature, which is also the graph node name.
    """
    gains = calEntropy(None, index)
    best = max(gains, key=gains.get)
    G.add_node(best, fontname="SimHei")
    return best
# Pick the root feature, then grow the rest of the tree from it.
root = createId3TreeRoot()
print("root=",root)
buildId3Tree(None, root, index.tolist())
# Lay out the graph and render it to an image file.
G.layout()
G.draw("id3.png", prog="dot")
有很多冗余,我不想改了,以后可能闲的没事改一下吧。凑合看吧,反正能画图。
数据集:
selectTable.py
import sqlite3
# Module-level connection and cursor shared by every select_table call.
conn = sqlite3.connect('table.db')
cur = conn.cursor()
print("Opened database successfully")
def select_table(condition=None, label="好瓜", sql=None, distict=False):
    """Query table ``test`` and return every matching row.

    Args:
        condition: optional SQL tail appended to the query, including the
            ``where`` keyword (e.g. ``" where  色泽 = '青绿'"``).
        label: column expression to select.
        sql: complete SQL statement; when given it is executed as is and
            all other arguments are ignored.
        distict: when true, select only distinct values of ``label``.

    Returns:
        list of row tuples as returned by ``cursor.fetchall()``.

    Note:
        The statement is assembled by string concatenation, so all
        arguments must come from trusted code, never from user input.
    """
    if sql:
        cur.execute(sql)
        # fetchmany() without a size returns only cursor.arraysize rows
        # (1 by default); return the full result like the other path.
        return cur.fetchall()

    query = "select "
    if distict:
        query += " DISTINCT "
    query += label + " from test"
    if condition:
        query += " " + condition
    cur.execute(query)
    return cur.fetchall()
画出来的图:
c4.5
import matplotlib.pyplot as plt
# -*- coding: utf-8 -*-
import math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
# Load the dataset; its column names are used as the candidate feature list.
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
# Directed, top-to-bottom graph that the decision tree is drawn into.
G = pgv.AGraph(directed=True, rankdir="TB",
compound=True, normalize=True, encoding='UTF-8')
# Running counter used by buildId3Tree to give every leaf node a unique name.
cnt = 0
print(index)
cases_num = float(file.shape[0])
# Column count minus one.
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)
# NOTE(review): calEntropy assigns local variables named hd and h_feature,
# so these two module-level values are never updated by it.
hd = 0.
h_feature = {}
# Per-feature bookkeeping dicts, filled by init().
use_feature={}
use_feature_Leaf={}
def _binary_entropy(p):
    """Return the base-2 entropy of a two-class distribution with P(class) = p.

    By convention 0 * log(0) is treated as 0, so the result is 0.0 for
    p <= 0 or p >= 1 instead of raising a math domain error.
    """
    if p <= 0 or p >= 1:
        return 0.0
    return -(p * math.log(p, 2) + (1 - p) * math.log(1 - p, 2))


def calEntropy(confidition, feature):
    """Compute the C4.5 gain ratio of every candidate feature.

    Args:
        confidition: SQL ``where`` clause (including the ``where`` keyword)
            selecting the current subset, or ``None``/empty for all rows.
        feature: iterable of column names; "编号" and "好瓜" are skipped.

    Returns:
        dict mapping feature name -> gain ratio (information gain divided
        by split information). A feature with a single distinct value in
        the subset has zero split information and gets ratio 0.0. Empty
        when the subset has no rows or no usable feature.
    """
    data = select_table(confidition)
    cases = len(data)
    if cases == 0:
        # Nothing to split; the original divided by zero here.
        return {}

    # Entropy of the class label on the whole subset.
    hd = _binary_entropy(data.count(('是',)) / float(cases))

    prefix = confidition + " and " if confidition else " where "

    h_feature = {}
    for idx in feature:
        if idx == "编号" or idx == "好瓜":
            continue
        labels = [
            row[0]
            for row in select_table(confidition, label=idx, distict=True)
            if len(row) == 1
        ]
        cond_entropy = 0.0  # H(D | idx)
        split_info = 0.0    # entropy of the partition induced by idx
        for label in labels:
            now_confidition = prefix + " " + idx + " == '" + label + "'"
            feature_data = select_table(condition=now_confidition)
            if not feature_data:
                continue
            k_ = float(len(feature_data)) / cases  # weight of this value
            m_ = feature_data.count(('是',)) / float(len(feature_data))  # good-melon rate
            cond_entropy += k_ * _binary_entropy(m_)
            split_info -= k_ * math.log(k_, 2)

        gain = hd - cond_entropy
        # A single-valued feature has split_info == 0 (and gain == 0);
        # dividing would raise ZeroDivisionError, so score it 0.
        h_feature[idx] = gain / split_info if split_info > 0 else 0.0
    return h_feature
def checkSameLable(confidition):
    """Tell whether all rows selected by ``confidition`` share one class.

    Returns:
        ``(True, '好瓜')`` when every selected row is ('是',),
        ``(True, '坏瓜')`` when none is, otherwise ``(False, None)``.
    """
    selected = select_table(condition=confidition)
    total = len(selected)
    good_count = float(selected.count(('是',)))
    if good_count == total:
        return True, '好瓜'
    if good_count == 0:
        return True, '坏瓜'
    return False, None
# bfs
def buildId3Tree(confidition, root, feature):
    """Recursively grow the C4.5 subtree below node ``root`` and draw it into ``G``.

    For every distinct value of ``root`` in the current subset, either a
    leaf is attached (the branch is pure, or no feature is left to split
    on) or the feature with the largest gain ratio becomes a child node
    and the function recurses into it.

    Args:
        confidition: SQL ``where`` clause selecting the current subset,
            or ``None``/empty for all rows.
        root: feature (column) name of the node being expanded; also used
            as the graph node name.
        feature: candidate feature names. The list is NOT modified; each
            call works on its own copy so sibling branches do not lose
            features that were used in another branch.
    """
    global cnt
    remaining = [name for name in feature if name != root]

    labels = [
        row[0]
        for row in select_table(confidition, label=root, distict=True)
        if len(row) == 1
    ]
    prefix = confidition + " and " if confidition else " where "

    for label in labels:
        now_confidition = prefix + " " + root + " = '" + label + "'"

        flag, nflag = checkSameLable(now_confidition)
        h_feature = {} if flag else calEntropy(now_confidition, remaining)

        if flag or not h_feature:
            if not flag:
                # Features exhausted on an impure branch: fall back to the
                # majority class (ties go to '好瓜').
                rows = select_table(condition=now_confidition)
                nflag = '好瓜' if 2 * rows.count(('是',)) >= len(rows) else '坏瓜'
            # The counter prefix keeps same-class leaves distinct, because
            # pygraphviz identifies nodes by name.
            nflag = "LeafNode:" + str(cnt) + " " + nflag
            G.add_node(nflag, fontname="SimHei")
            G.add_edge(root, nflag, label=label, fontname="SimHei",
                       color="black", style="dashed", penwidth=1.5)
            cnt = cnt + 1
            continue

        # Split on the feature with the largest gain ratio.
        maxidx = max(h_feature, key=h_feature.get)
        print("maxidx=", maxidx)
        # The node shows the feature name; the branch value goes on the edge.
        # NOTE(review): a feature chosen in two different branches maps to
        # the same graph node, since nodes are keyed by name.
        G.add_node(maxidx, fontname="SimHei")
        G.add_edge(root, maxidx, label=label, fontname="SimHei")
        # Recurse into the newly chosen feature (the original passed `root`
        # again, which fails on the second feature.remove(root)).
        buildId3Tree(now_confidition, maxidx, remaining)
def createId3TreeRoot():
    """Choose the root feature of the tree and add it to ``G``.

    Scores every column of ``index`` with ``calEntropy`` on the whole
    table and picks the one with the largest score.

    Returns:
        Name of the chosen feature, which is also the graph node name.
    """
    ratios = calEntropy(None, index)
    winner = max(ratios, key=ratios.get)
    G.add_node(winner, fontname="SimHei")
    return winner
def init():
    """Reset the per-feature bookkeeping for every column in ``index``.

    Sets ``use_feature[name]`` to False and ``use_feature_Leaf[name]`` to 0.
    """
    global use_feature
    use_feature.update((name, False) for name in index)
    use_feature_Leaf.update((name, 0) for name in index)
# Fill the bookkeeping dicts before building the tree.
init()
# Pick the root feature, then grow the rest of the tree from it.
root = createId3TreeRoot()
print("root=",root)
buildId3Tree(None, root, index.tolist())
# Lay out the graph and render it to an image file.
G.layout()
G.draw("c45.png", prog="dot")
里面函数名忘记改了,凑合看吧,反正实现是c4.5
图:
cart
import matplotlib.pyplot as plt
# -*- coding: utf-8 -*-
import math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
# Load the dataset; its column names are used as the candidate feature list.
file = pd.read_excel('watermelon20.xlsx')
index = file.columns.values
# Directed, top-to-bottom graph that the decision tree is drawn into.
G = pgv.AGraph(directed=True, rankdir="TB",
compound=True, normalize=True, encoding='UTF-8')
# Running counter used by buildGiniTree to give every leaf node a unique name.
cnt = 0
print(index)
cases_num = float(file.shape[0])
# Column count minus one.
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)
# NOTE(review): the four names below are not referenced by any function
# visible in this script.
hd = 0.
h_feature = {}
use_feature={}
use_feature_Leaf={}
def calGini(confidition, feature):
    """Find, for every candidate feature, its best binary split by Gini index.

    A split on ``(feature, value)`` partitions the current subset into the
    rows where ``feature = value`` and the rows where ``feature != value``.

    Args:
        confidition: SQL ``where`` clause (including the ``where`` keyword)
            selecting the current subset, or ``None``/empty for all rows.
        feature: iterable of column names; "编号" and "好瓜" are skipped.

    Returns:
        ``(gini_num, gini_name)``: two dicts keyed by feature name, holding
        the smallest weighted Gini index and the value that achieves it.
        Features with no real binary split (fewer than two distinct values
        in the subset) are omitted, so both dicts may be empty.
    """
    cases = len(select_table(confidition))
    prefix = confidition + " and " if confidition else " where "

    gini_name = {}
    gini_num = {}
    for idx in feature:
        if idx == "编号" or idx == "好瓜":
            continue
        labels = [
            row[0]
            for row in select_table(confidition, label=idx, distict=True)
            if len(row) == 1
        ]
        minGini = None
        minName = ""
        for label in labels:
            now_confidition = prefix + " " + idx + " = '" + label + "'"
            not_confidition = prefix + " " + idx + " != '" + label + "'"
            feature_data = select_table(condition=now_confidition)
            other_feature_data = select_table(condition=not_confidition)
            if not feature_data or not other_feature_data:
                # One side is empty, so this is not a split; the original
                # raised ZeroDivisionError on the empty side.
                continue
            k_ = float(len(feature_data)) / cases  # weight of the "=" side
            m_ = feature_data.count(('是',)) / float(len(feature_data))
            other_ = other_feature_data.count(('是',)) / float(len(other_feature_data))
            # For two classes, Gini(p) = 2 * p * (1 - p).
            temp = k_ * (2 * m_ * (1 - m_)) + (1 - k_) * (2 * other_ * (1 - other_))
            if minGini is None or temp < minGini:
                minGini = temp
                minName = label
        if minGini is not None:
            gini_name[idx] = minName
            gini_num[idx] = minGini
    return gini_num, gini_name
def checkSameLable(confidition):
    """Tell whether all rows selected by ``confidition`` share one class.

    Returns:
        ``(True, '好瓜')`` when every selected row is ('是',),
        ``(True, '坏瓜')`` when none is, otherwise ``(False, None)``.
    """
    matched = select_table(condition=confidition)
    n_good = float(matched.count(('是',)))
    if n_good == len(matched):
        verdict = (True, '好瓜')
    elif n_good == 0:
        verdict = (True, '坏瓜')
    else:
        verdict = (False, None)
    return verdict
# bfs
def _attach_gini_branch(condition, parent, edge_label, is_pure, leaf_label, feature):
    """Attach one outcome of a binary split below ``parent``: a leaf or a subtree."""
    global cnt
    edge_style = dict(fontname="SimHei", color="black", style="dashed", penwidth=1.5)

    if not is_pure:
        gini_num, gini_name = calGini(condition, feature)
        if gini_num:
            # Split on the (feature, value) pair with the smallest Gini index.
            minidx = min(gini_num, key=gini_num.get)
            child = minidx + gini_name[minidx]
            # NOTE(review): the same (feature, value) pair chosen in two
            # branches maps to one graph node, since nodes are keyed by name.
            G.add_node(child, fontname="SimHei")
            G.add_edge(parent, child, label=edge_label, **edge_style)
            buildGiniTree(condition, minidx, gini_name[minidx], feature)
            return
        # No feature left on an impure branch: fall back to the majority
        # class (ties go to '好瓜').
        rows = select_table(condition=condition)
        leaf_label = '好瓜' if 2 * rows.count(('是',)) >= len(rows) else '坏瓜'

    # The counter prefix keeps same-class leaves distinct, because
    # pygraphviz identifies nodes by name.
    leaf = "LeafNode:" + str(cnt) + " " + leaf_label
    G.add_node(leaf, fontname="SimHei")
    G.add_edge(parent, leaf, label=edge_label, **edge_style)
    cnt = cnt + 1


def buildGiniTree(confidition, root, label, feature):
    """Recursively grow the CART subtree below the split ``root = label``.

    The current subset is divided into the rows where ``root = label``
    (edge "是") and the rows where ``root != label`` (edge "否"). Each side
    becomes a leaf when it is pure, otherwise a new split chosen by
    ``calGini`` on that side's own rows.

    Args:
        confidition: SQL ``where`` clause selecting the current subset,
            or ``None``/empty for all rows.
        root: feature (column) name of the split being expanded.
        label: value of ``root`` that defines the split; ``root + label``
            is the graph node name of this split.
        feature: candidate feature names. The list is NOT modified; each
            branch receives its own copy without ``root``.
    """
    # As in the original, `root` is dropped for both branches.
    remaining = [name for name in feature if name != root]

    prefix = confidition + " and " if confidition else " where "
    now_confidition = prefix + " " + root + " = '" + label + "'"
    not_confidition = prefix + " " + root + " != '" + label + "'"

    flag, nflag = checkSameLable(now_confidition)
    not_flag, not_nflag = checkSameLable(not_confidition)

    parent = root + label
    # Both sides are always built, each with its own class label and its
    # own condition.
    _attach_gini_branch(now_confidition, parent, "是", flag, nflag, list(remaining))
    _attach_gini_branch(not_confidition, parent, "否", not_flag, not_nflag, list(remaining))
def createTreeRoot():
    """Choose the root split of the CART tree and add its node to ``G``.

    Scores every column of ``index`` with ``calGini`` on the whole table
    and picks the feature whose best split has the smallest Gini index.

    Returns:
        ``(feature, value)`` of the chosen split; the graph node is named
        ``feature + value``.
    """
    scores, split_values = calGini(None, index)
    best = min(scores, key=scores.get)
    chosen_value = split_values[best]
    G.add_node(best + chosen_value, fontname="SimHei")
    return best, chosen_value
# Pick the root split, then grow the rest of the tree from it.
root ,label= createTreeRoot()
print("root=",root)
print("label=",label)
buildGiniTree(None, root, label,index.tolist())
# Lay out the graph and render it to an image file.
G.layout()
G.draw("cart.png", prog="dot")
我太菜了,搞半天搞出来这个。主要是开头难,id3写了好久建树,写完id3,后面两个算法半个小时不到改一改就行了。
想放github 上的,怎么github挂了呀?时好时坏的