一段简单的决策树可视化代码详解

注:代码来源于GitHub ljpzzz

源码网址

machinelearning/classic-machine-learning/decision_tree_classifier_1.ipynb at master · ljpzzz/machinelearning · GitHub

 全部代码展示(后面有关键代码讲解)

from itertools import product

import numpy as np
import matplotlib.pyplot as plt

#Python 中常用的机器学习库。这里从它里面导入了datasets模块,用于加载一些内置的示例数据集
from sklearn import datasets

#导入了DecisionTreeClassifier类,用于构建决策树分类模型。
from sklearn.tree import DecisionTreeClassifier

# 仍然使用自带的iris数据(经典鸢尾花)
iris = datasets.load_iris()
#切片操作,取所有行的索引为0和2的两列,作为输入数据
X = iris.data[:, [0, 2]]
#输出数据是鸢尾花的分类标签,y是一个一维数组
y = iris.target

# 训练模型,限制树的最大深度4
clf = DecisionTreeClassifier(max_depth=4)
# 拟合模型
clf.fit(X, y)

# 画图,x_min是所有的第一列数据的最小值-1,x_max是第一列数据的最大值+1,表示绘图区域的边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
#这一行同理,确定y的上下边界
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1

#使用np.meshgrid函数结合np.arange函数生成网格数据。
# np.arange(x_min, x_max, 0.1)会生成在x_min到x_max范围内,步长为 0.1 的一维数组,
# np.meshgrid则将这两个分别对应x轴和y轴方向的一维数组转换为二维的网格数据,即xx和yy
#这样说比较模糊,后面有例子,建议不懂先跳到后面看例子
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

#将xx和yy用ravel展平成两个一维数组,再用np.c_按列拼接成二维数组
#这样得到的np.c_[xx.ravel(), yy.ravel()],就是一个二维数组
#它的每一行是两个位,分别是花萼和花瓣长度两个特征
#只有这种形状才能用clf预测,得到预测结果Z,是一个一维数组
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

#将Z变形为xx的形状,即变回二维数组
Z = Z.reshape(xx.shape)

#根据xx和yy确定的网格坐标绘制等高线图,Z是每个点预测的类别,也是分类依据,
# 某个坐标上的颜色取决于这一坐标对应的预测类别 
plt.contourf(xx, yy, Z, alpha=0.4)
#绘制散点图,x为第一个特征,y为第二个特征,c=y表示散点的颜色按照y来分,
# y是每个特征点对应的实际类别,不同类的点颜色不同
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()

关键部分讲解(适合小白)

1.xx和yy是什么意思

向代码中添加print语句,观察在meshgrid发生作用前后,发生了什么变化

print(np.arange(x_min, x_max, 0.1))
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))
print(xx)

输出如下

​​​​​

可以看到上面是np.arange()对应的输出,是步长0.1的一维数组,合情合理,下面是什么呢?

观察发现,每行都是相同的,且与第一个输出一样,可以把二维数组理解为一个坐标网的若干个坐标点,而数值表示的是该点横坐标的数值,事实上也是如此。而这张二维数组的行数,与meshgrid的第二个参数数组的元素个数有关。

再看yy

print(np.arange(y_min, y_max, 0.1))
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))
print(yy)

发现第二个输出的每列均相同,都与第一个输出相同。数值表示坐标网的每个坐标点对应的纵坐标。

2.ravel和np.c_的作用

首先,ravel用于展平数组,如下。

print(xx)
print()
print(xx.ravel())

输出结果

可见这个有点像np.meshgrid的逆过程。yy同理。

至于np.c_,是为了按列将两个一维数组拼接起来,如下。

print(xx.ravel())
print(yy.ravel())
print(np.c_[xx.ravel(), yy.ravel()])

输出结果

也试一下按行拼接,即np.r_

print(xx.ravel())
print(yy.ravel())
print(np.r_[xx.ravel(), yy.ravel()])

输出结果 

3.Zreshape前后变化 

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
print(Z)
Z = Z.reshape(xx.shape)
print()
print(Z)

因此,展平的目的是为了拼出一个符合决策树输入要求的数组,即一个样本为一行,有若干行的形式。而meshgrid的目的是为了画图。

可视化呈现

如果把树的深度调为2,也就是分类能力没有那么强,对应的图如下

 如果将树的深度调为10,此时分类能力很强,但会过拟合,图上很直观地反映了这一点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值