class Branch:
    """One node of the binary decision tree built by ``decision_tree_inner``.

    The class-level values below are the defaults a fresh node starts with.
    They are kept at class level so existing ``Branch.<attr>`` lookups keep
    working; ``__init__`` additionally gives each node its own ``value`` list.
    """

    no = 0                  # sequence number taken from the global counter
    depth = 1               # depth of this node in the tree
    column = ''             # feature chosen for the split ('' when none was found)
    entropy = 0             # entropy of the class counts stored in ``value``
    samples = 0             # number of rows that reached this node
    value = []              # class-level default only; instances get their own list
    branch_positive = None  # child Branch built from the "positive" subset
    branch_negative = None  # child Branch built from the "negative" subset
    no_positive = 0         # counter value reserved for the positive side
    no_negative = 0         # counter value reserved for the negative side

    def __init__(self):
        # A list stored on the class is shared by every instance, so mutating
        # ``node.value`` in place would silently change all other nodes.
        # Start every node with a private list instead.
        self.value = []
# Module-level counter used to hand out sequence numbers to tree nodes;
# functions that advance it must declare ``global number``.
number = 0
def decision_tree_inner(data, label, depth, max_depth=3):
    """Recursively build the subtree for ``data`` and return its root node.

    Each node created takes the next value of the module-level ``number``
    counter as its ``no``. When recursion stops at a node, two further
    counter values are reserved and stored in ``no_positive`` and
    ``no_negative`` (presumably ids for the two implicit leaves -- confirm
    against whatever renders the tree).

    Args:
        data: Sample table. Must support ``data.shape[0]``, ``data[label]``
            and boolean-mask indexing (presumably a pandas DataFrame --
            TODO confirm against callers).
        label: Key of the target column; rows whose value equals 1 are
            counted as positive.
        depth: Depth recorded on the node being built.
        max_depth: Depth at which recursion stops. ``None`` disables the
            depth limit, so only a missing split column ends recursion.

    Returns:
        Branch: the root of the subtree built from ``data``.
    """
    global number

    branch = Branch()
    branch.no = number
    number += 1
    branch.depth = depth
    branch.samples = data.shape[0]

    # Class counts are stored as [negatives, positives].
    n_positive = data[data[label] == 1].shape[0]
    branch.value = [branch.samples - n_positive, n_positive]
    branch.entropy = information_entropy(branch.value)

    # find_best_feature is defined elsewhere. As used here, index 0 is the
    # split column ('' meaning no split), index 2 the positive subset and
    # index 3 the negative subset.
    best_feature = find_best_feature(data, label)
    branch.column = best_feature[0]

    # ``>=`` rather than ``==`` so a call that starts below max_depth still
    # terminates instead of recursing until no split column is left.
    depth_exhausted = max_depth is not None and depth >= max_depth
    if depth_exhausted or branch.column == '':
        branch.no_positive = number
        branch.no_negative = number + 1
        number += 2
        return branch

    # The negative subtree is built first; node numbering depends on this order.
    branch.branch_negative = decision_tree_inner(
        best_feature[3], label, depth + 1, max_depth=max_depth
    )
    branch.branch_positive = decision_tree_inner(
        best_feature[2], label, depth + 1, max_depth=max_depth
    )
    return branch
def decision_tree(data, label, max_depth=3):
    """Build a decision tree for ``data`` and return its root node.

    Resets the module-level ``number`` counter so node numbering starts at 0
    for every tree, then delegates to ``decision_tree_inner`` starting at
    depth 0.

    Args:
        data: Sample table passed through to ``decision_tree_inner``.
        label: Key of the target column.
        max_depth: Depth limit forwarded to ``decision_tree_inner``.

    Returns:
        The root node returned by ``decision_tree_inner``.
    """
    # Without ``global`` the assignment below only bound a local name, so the
    # counter kept growing across calls.
    global number
    number = 0

    # NOTE(review): the result was never used. The call is retained because
    # data_entropy is defined outside this block and side effects cannot be
    # ruled out from here -- drop it if it is a pure function.
    data_entropy(data)

    # Forward the caller's max_depth; it used to be hard-coded to 3 here.
    return decision_tree_inner(data, label, 0, max_depth=max_depth)
# Build the hand-written tree on data_train with 'low' as the target column,
# requesting max_depth=2.
my_dt = decision_tree(data_train, 'low', max_depth=2)
# Fit scikit-learn's entropy-based classifier with max_depth=3 on
# X_train / y_train (presumably as a reference to compare against -- confirm).
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(criterion='entropy', max_depth=3)
# The value returned by fit() is kept under a second name, ``model``.
model=dt.fit(X_train, y_train)
# Decision tree (决策树)
# Latest recommended article published on 2023-07-23 17:32:00 (最新推荐文章于 2023-07-23 17:32:00 发布)