Premade Estimators
本文档介绍了 TensorFlow 编程环境,并展示了怎么用 Premade Estimators 来解决 Iris 分类问题。
文章目录
需要安装的包
在运行本文的代码之前,你需要安装以下包:
- 安装TensorFlow。
- 如果你是在virtualenv或Anaconda中安装的TensorFlow,请activate你的TensorFlow环境。
- 运行下面的代码安装、更新pandas包:
pip install pandas
获取本文的代码
通过以下步骤来获得本文的代码:
- 使用下面的命令将TensorFlow Model repository克隆到本地:
git clone https://github.com/tensorflow/models
- 将目录切到本文使用的代码的位置:
cd models/samples/core/get_started/
本文使用的代码是 premade_estimator.py。该程序使用 iris_data.py 代码去fetch训练数据。
运行本文的代码
使用下面的方式来运行本文的代码。例如:
python premade_estimator.py
程序会在训练过程中输出训练日志,然后程序会输出在测试集上的测试结果。
...
Prediction is "Setosa" (99.6%), expected "Setosa"
Prediction is "Versicolor" (99.8%), expected "Versicolor"
Prediction is "Virginica" (97.9%), expected "Virginica"
如果程序运行的结果有误,请进行如下检查:
- TensorFlow安装有没有问题
- TensorFlow的版本是否正确
- 确定activate了安装TensorFlow的环境
1. TensorFlow 编程环境介绍
在使用 TensorFlow 编程之前,让我们首先研究下 TF 编程环境。如下图所示,TensorFlow 提供了一个包含多个 API 层的编程堆栈:
我们强烈推荐使用下列 API 编写 TF 程序:
- Estimators:代表一个完整的模型。Estimator API 提供一些方法来训练模型、评估模型性能并生成预测。
- Datasets:构建数据输入管道。Dataset API 提供了一些方法来加载数据(E),操作数据(T),并将数据馈送到您的模型中(L)。Dataset API 与 Estimator API 完美兼容。注:构建数据输入管道的过程其实就是构建 ETL 的过程。
2. Iris 分类问题
Iris数据集包含了3种鸢尾花、150个样本,每个样本包含鸢尾花的四个特征值:花蕊长度(cm)、花蕊宽度(cm)、花瓣长度(cm)、花瓣宽度(cm)。
1.1 Iris数据集简介
Iris 数据集包含4个特征和1个标签。4个特征分别为:
- 花蕊长度(sepal length)
- 花蕊宽度(sepal width)
- 花瓣长度(petal length)
- 花瓣宽度(petal width)
标签表明了 Iris 的品种,共有三个品种:
- Iris setosa (0)
- Iris versicolor (1)
- Iris virginica (2)
下面的表是数据集的一个片段:
花蕊长度(cm) | 花蕊宽度(cm) | 花瓣长度(cm) | 花瓣宽度(cm) | 品种(标签) |
---|---|---|---|---|
5.1 | 3.3 | 1.7 | 0.5 | 0 (setosa) |
5.0 | 2.3 | 3.3 | 1.0 | 1 (versicolor) |
6.4 | 2.8 | 5.6 | 2.2 | 2 (virginica) |
1.2 分类算法
本文档使用一个 DNN 来实现对 Iris 的分类,该分类器的概况如下:
- 2 个隐藏层
- 每一个隐藏层包含 10 个节点
特征、隐藏层、预