在本练习中,您将从头开始实施决策树,并将其应用于蘑菇可食用还是有毒的分类任务。
1-包
首先,让我们运行下面的单元格来导入此分配过程中所需的所有包。
- numpy是在Python中处理矩阵的基本包。
- matplotlib是一个著名的Python绘图库。
- utils.py包含此赋值的辅助函数。您不需要修改此文件中的代码。
import numpy as np
import matplotlib.pyplot as plt
from public_tests import *
%matplotlib inline
2-问题陈述
假设你正在创办一家种植和销售野生蘑菇的公司。
- 由于并非所有蘑菇都是可食用的,因此您希望能够根据蘑菇的物理属性来判断其是否可食用或有毒。
- 您有一些现有的数据可以用于此任务。
你能用这些数据来帮助你确定哪些蘑菇可以安全销售吗?
注:所使用的数据集仅用于说明目的。它并不意味着成为识别食用蘑菇的指南。
3-数据集
您将从加载此任务的数据集开始。您收集的数据集如下:
- 你有10个蘑菇的例子。对于每个示例,都有
- “三个特征”
- “Cap Color”(棕色或红色)
- “Stalk Shape”(锥形或扩大)
- “Solitary”(是或否)
- “Label Edible”
- (1表示是,0表示有毒)
- “三个特征”
3.1一个热编码数据集
为了便于实现,我们对特性进行了热编码(将它们转换为0或1值特性)
因此
- X_train包含每个示例的三个特征
- 棕色(值1表示“棕色”的菌盖颜色,0表示“红色”的菌帽颜色)
- 锥形(值1指示“锥形茎形状”,0表示茎形状“扩大”)
- 孤立(值1表明“是”,0表明“否”)
- y_train表示蘑菇是否可食用
- y=1表示可食用
- y=0表示有毒
X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
查看变量
让我们更熟悉您的数据集。
- 一个好的开始是打印出每个变量,看看它包含什么。
下面的代码打印X_train的前几个元素和变量的类型。
print("First few elements of X_train:\n", X_train[:5])
print("Type of X_train:",type(X_train))
First few elements of X_train:
[[1 1 1]
[1 0 1]
[1 0 0]
[1 0 0]
[1 1 1]]
Type of X_train: <class 'numpy.ndarray'>
现在,让我们为y_train做同样的事情
print("First few elements of y_train:", y_train[:5])
print("Type of y_train:",type(y_train))