一、什么是鸢尾花数据集
鸢尾花数据集作为入门经典数据集。Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
在三个类别中,其中有一个类别和其他两个类别是线性可分的。另外。在sklearn中已内置了此数据集。
from sklearn.datasets import load_iris #导入IRIS数据集
iris = load_iris() #特征矩阵
二、数据的载入
通过手动的方式将数据集进行读入
f = open(path)
2 x = []
3 y = []
4 for d in f:
5 d = d.strip()
6 if d:
7 d = d.split(',')
8 y.append(d[-1])
9 x.append(list(map(float, d[:-1])))
10 x = np.array(x)
11 y = np.array(y)
12 print(x)
13 y[y == 'Iris-setosa'] = 0
14 y[y == 'Iris-versicolor'] = 1
15 y[y == 'Iris-virginica'] = 2
16 print(y)
17 y = y.astype(dtype=np.int)
18 print(y)
通过pandas库来进行数据的读取
import numpy as np
2 from sklearn.linear_model import LogisticRegression
3 import matplotlib.pyplot as plt
4 import matplotlib as mpl
5 from sklearn import preprocessing
6 import pandas as pd
7 from sklearn.preprocessing import StandardScaler
8 from sklearn.pipeline import Pipeline
9
10 df = pd.read_csv(path, header=0)
11 x = df.values[:, :-1]
12 y = df.values[:, -1]
13 print('x = \n', x)
14 print('y = \n', y)
15 le = preprocessing.LabelEncoder()
16 le.fit(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
17 print(le.classes_)
18 y = le.transform(y)
19 print('Last Version, y = \n', y)
三、模型可视化
我们可以在所选特征的范围内,从最大值到最小值构建一系列的数据,使得它能覆盖整个的特征数据范围,然后预测这些值所属的分类,并给它们所在的区域
上色,这样我们就能够清楚的看到模型每个分类的区域了,具体的代码如下所示:
1 N, M = 500, 500 # 横纵各采样多少个值
2 x1_min, x1_max = x[:, 0].min(), x[:, 0].max() # 第0列的范围
3 x2_min, x2_max = x[:, 1].min(), x[:, 1].max() # 第1列的范围
4 t1 = np.linspace(x1_min, x1_max, N)
5 t2 = np.linspace(x2_min, x2_max, M)
6 x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点
7 x_test = np.stack((x1.flat, x2.flat), axis=1) # 测试点
8
9 cm_light = mpl.colors.ListedColormap(['#77E0A0', '#FF8080', '#A0A0FF'])
10 cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
11 y_hat = lr.predict(x_test) # 预测值
12 y_hat = y_hat.reshape(x1.shape) # 使之与输入的形状相同
13 plt.pcolormesh(x1, x2, y_hat, cmap=cm_light) # 预测值的显示
14 plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=50, cmap=cm_dark)
15 plt.xlabel('petal length')
16 plt.ylabel('petal width')
17 plt.xlim(x1_min, x1_max)
18 plt.ylim(x2_min, x2_max)
19 plt.grid()
20 plt.savefig('2.png')
21 plt.show()
1227

被折叠的 条评论
为什么被折叠?



