一、载入数据
载入常用库
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
载入数据
IRIS_TRAINING = 'iris_training.csv'
IRIS_TEST = 'iris_test.csv'
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING
target_dtype=http://np.int
features_dtype=np.float(32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST
target_dtype=http://np.int
feature_dtype=np.float(32)
二、构建神经网络分类器
features_columns = [tf.contrib.layers.real_valued_column('', dimension=4) #数据连续,4特征
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10], #3层隐藏层,分别10,20,10个神经元
n_classes=3, #3目标类别
model_dir='/tmp/iris_model') #保存训练记录
三、训练模型
classifier.fit(x=training_set.data, y=training_set.target, steps=2000)四、评估模型
accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)['accuracy']
print('Accuracy: {0:f}'.format(accuracy_score))
五、对新样本进行分类
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Prediction: {}'.format(str(y)))
Iris数据集上的TensorFlow分类
本文介绍如何使用TensorFlow构建神经网络分类器来预测鸢尾花的种类。通过加载鸢尾花数据集并训练模型,最终实现对新样本的有效分类。文中详细展示了数据载入、模型构建、训练及评估的过程。
739

被折叠的 条评论
为什么被折叠?



