一、数据导入
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
# 数据导入
datas = pd.read_csv(r'E:/iris.csv')
datas.columns = ['A', 'B', 'C', 'D', 'categories'] # A、B、C、D为花瓣长度等特征
datas.head(10)
A | B | C | D | categories | |
---|---|---|---|---|---|
0 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
1 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
2 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
3 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
4 | 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
5 | 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
6 | 5.0 | 3.4 | 1.5 | 0.2 | Iris-setosa |
7 | 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
8 | 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
9 | 5.4 | 3.7 | 1.5 | 0.2 | Iris-setosa |
二、数据预处理
查看数据集的总体参数:
# 去重,查看数据类型,查看是否有缺省值
datas.drop_duplicates(inplace=True)
datas.info()
<class 'pandas.core.frame.DataFrame'>
Index: 146 entries, 0 to 148
Data columns (total 5 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 A 146 non-null float64
1 B 146 non-null float64
2 C 146 non-null float64
3 D 146 non-null float64
4 categories 146 non-null object
dtypes: float64(4), object(1)
memory usage: 6.8+ KB
以上结论:每一列都没有缺失值,数据类型也正确。
查看数据的统计特征:
# 查看数据的统计特征(简单检查是否存在异常值)
datas.describe()
A | B | C | D | |
---|---|---|---|---|
count | 146.000000 | 146.000000 | 146.000000 | 146.000000 |
mean | 5.861644 | 3.052740 | 3.796575 | 1.215753 |
std | 0.829562 | 0.436948 | 1.753987 | 0.755822 |
min | 4.300000 | 2.000000 | 1.000000 | 0.100000 |
25% | 5.100000 | 2.800000 | 1.600000 | 0.300000 |
50% | 5.800000 | 3.000000 | 4.400000 | 1.300000 |
75% | 6.400000 | 3.300000 | 5.100000 | 1.800000 |
max | 7.900000 | 4.400000 | 6.900000 | 2.500000 |
以上结论:特征的统计数据大致都在合理范围内,无明显异常值。
查看鸢尾花的类别有哪些(即有哪些分类标签),并查看占比:
# 查看鸢尾花有哪些类别,以及各类别的数量
datas.groupby(['categories'])['categories'].count()
categories
Iris-setosa 47
Iris-versicolor 50
Iris-virginica 49
Name: categories, dtype: int64
以上结论:鸢尾花共有3种类别,分别为“setosa”、“versicolor”、“virginica”,且每种的数据占比是均衡的。
对“categories”列进行标签编码(即用0、1、2等代替类别,变成计算机能对读取的格式):
# 标签编码
datas['categories'] = datas['categories'].map({
'Iris-setosa':0, 'Iris-versicolor':1, 'Iris-virginica':2
})
输出特征关系矩阵,查看特征间的相互关系:
# 特征之间关系的矩阵图
sns.pairplot(datas, palette='viridis', hue='categories')
plt.show()
特征关系矩阵的对角线上的图表示该特征(A、B、…)的分布直方图。非对角先上的图可以查看两个特征之间的相关关系,比如第一行第二幅图,可以看出特征A和特征B的相关关系不是很强,右上角那幅图可以看出特征A和D直接有较强的相关关系。
除了可以直观的查看相关关系(非对角线),还可以根据散点的颜色和位置,判断使用这两个特征来进行分类的效果,比如第一行第二幅图,可以看出只特征A和特征B并不能很好的进行分类,而右上角那幅图,使用特征A和D的分类效果比第一行第二幅图好点。
特征关系图的结论:我们可以只使用特征C和特征D进行分类(特征选择),分类效果较好。
特征选择与数据标准化(消除量纲影响):
from sklearn.preprocessing import StandardScaler
# 特征选择
X = datas[['C','D']]
y = datas['categories']
# 数据标准化
scaler = StandardScaler()
X_ = scaler.fit_transform(X)
三、建模与预测
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_, y, test_size=0.3, random_state=40)
# 建模
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
# 预测
y_pred = knn.predict(X_test)
四、模型评估
from sklearn.metrics import classification_report
# 输出分类报告
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 1.00 1.00 1.00 11
1 1.00 0.92 0.96 13
2 0.95 1.00 0.98 20
accuracy 0.98 44
macro avg 0.98 0.97 0.98 44
weighted avg 0.98 0.98 0.98 44
以上结论:分类报告中,模型的准确率为98%,表明KNN算法在鸢尾花数据集上表现良好。各类别的精确率(precision)、召回率(recall)和F1分数(F1-score)均较高,说明该模型对各类别的分类效果较好。
五、分类结果可视化
# 创建一个DataFrame对象用于结果可视化
view = pd.DataFrame(X_test, columns=['C', 'D'])
view['y_test']=y_test.reset_index(drop=True)
view['predict'] = y_pred
view['is_correct'] = view['y_test'] == view['predict']
# 绘制散点图
sns.scatterplot(
x='C',
y='D',
data=view,
hue='predict', # 根据预测类别着色
style='is_correct', # 根据是否正确分类选择标记样式
markers={True: 'o', False: 'X'},
palette='viridis',
s=180
)
plt.title('分类结果可视化')
plt.xlabel('特征C')
plt.ylabel('特征D', rotation=0)
plt.show()
结论:分类结果可视化显示,只有一个数据被错误预测,说明该拟合的模型在该数据集上的分类预测效果很好。
# 文章如有错误,欢迎大家指正。我们下期再见叭