注:代码来源于GitHub ljpzzz
源码网址
全部代码展示(后面有关键代码讲解)
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
#Python 中常用的机器学习库。这里从它里面导入了datasets模块,用于加载一些内置的示例数据集
from sklearn import datasets
#导入了DecisionTreeClassifier类,用于构建决策树分类模型。
from sklearn.tree import DecisionTreeClassifier
# 仍然使用自带的iris数据(经典鸢尾花)
iris = datasets.load_iris()
#切片操作,取所有行的索引为0和2的两列,作为输入数据
X = iris.data[:, [0, 2]]
#输出数据是鸢尾花的分类标签,y是一个一维数组
y = iris.target
# 训练模型,限制树的最大深度4
clf = DecisionTreeClassifier(max_depth=4)
# 拟合模型
clf.fit(X, y)
# 画图,x_min是所有的第一列数据的最小值-1,x_max是第一列数据的最大值+1,表示绘图区域的边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
#这一行同理,确定y的上下边界
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
#使用np.meshgrid函数结合np.arange函数生成网格数据。
# np.arange(x_min, x_max, 0.1)会生成在x_min到x_max范围内,步长为 0.1 的一维数组,
# np.meshgrid则将这两个分别对应x轴和y轴方向的一维数组转换为二维的网格数据,即xx和yy
#这样说比较模糊,后面有例子,建议不懂先跳到后面看例子
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
#将xx和yy用ravel展平成两个一维数组,再用np.c_按列拼接成二维数组
#这样得到的np.c_[xx.ravel(), yy.ravel()],就是一个二维数组
#它的每一行是两个位,分别是花萼和花瓣长度两个特征
#只有这种形状才能用clf预测,得到预测结果Z,是一个一维数组
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
#将Z变形为xx的形状,即变回二维数组
Z = Z.reshape(xx.shape)
#根据xx和yy确定的网格坐标绘制等高线图,Z是每个点预测的类别,也是分类依据,
# 某个坐标上的颜色取决于这一坐标对应的预测类别
plt.contourf(xx, yy, Z, alpha=0.4)
#绘制散点图,x为第一个特征,y为第二个特征,c=y表示散点的颜色按照y来分,
# y是每个特征点对应的实际类别,不同类的点颜色不同
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()
关键部分讲解(适合小白)
1.xx和yy是什么意思
向代码中添加print语句,观察在meshgrid发生作用前后,发生了什么变化
print(np.arange(x_min, x_max, 0.1))
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
print(xx)
输出如下
可以看到上面是np.arange()对应的输出,是步长0.1的一维数组,合情合理,下面是什么呢?
观察发现,每行都是相同的,且与第一个输出一样,可以把二维数组理解为一个坐标网的若干个坐标点,而数值表示的是该点横坐标的数值,事实上也是如此。而这张二维数组的行数,与meshgrid的第二个参数数组的元素个数有关。
再看yy
print(np.arange(y_min, y_max, 0.1))
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
print(yy)
发现第二个输出的每列均相同,都与第一个输出相同。数值表示坐标网的每个坐标点对应的纵坐标。
2.ravel和np.c_的作用
首先,ravel用于展平数组,如下。
print(xx)
print()
print(xx.ravel())
输出结果
可见这个有点像np.meshgrid的逆过程。yy同理。
至于np.c_,是为了按列将两个一维数组拼接起来,如下。
print(xx.ravel())
print(yy.ravel())
print(np.c_[xx.ravel(), yy.ravel()])
输出结果
也试一下按行拼接,即np.r_
print(xx.ravel())
print(yy.ravel())
print(np.r_[xx.ravel(), yy.ravel()])
输出结果
3.Zreshape前后变化
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
print(Z)
Z = Z.reshape(xx.shape)
print()
print(Z)
因此,展平的目的是为了拼出一个符合决策树输入要求的数组,即一个样本为一行,有若干行的形式。而meshgrid的目的是为了画图。
可视化呈现
如果把树的深度调为2,也就是分类能力没有那么强,对应的图如下
如果将树的深度调为10,此时分类能力很强,但会过拟合,图上很直观地反映了这一点。