本节我们采用IRIS数据集实现一个简单的二值分类器来预测一朵花是否为山鸢尾。
数据输入:花瓣长度、花瓣宽度
数据输出:山鸢尾标记为1,其他物种标记为0
模型:y=sigmoid(x1-(a*x2+b))
损失函数:sigmoid交叉熵损失函数
代码如下:
首先导入库并创建计算图会话
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
sess = tf.Session()
其次导入数据集,这里引用sklearn内的datasets里的iris数据集,特征选用花瓣长度和花瓣宽度,这两个特征在第3和第4列,同时将山鸢尾标记为1,其余物种标记为0
iris = datasets.load_iris()
data = iris.data[:,2:4]
data_target = np.array([1.0 if x==0 else 0.0 for x in iris.target])
print('data_example'.center(50,'-'))
print('iris_data:',data[1:5,:])
print('data_target:',data_target[1:5])
数据示例为:
-------------------data_example-------------------
iris_data: [[ 1.4 0.2]
[ 1.3 0.2]
[ 1.5 0.2]
[ 1.4 0.2]]
data_target: [ 1. 1. 1. 1.]
接下来声明批量训练大小。数据占位符和模型变量