💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在优快云上与你们相遇~💖
本博客的精华专栏:
【自动化测试】 【测试经验】 【人工智能】 【Python】
Sklearn 实战:IRIS 数据集分类与测试集预估全过程
在机器学习入门阶段,IRIS 鸢尾花数据集因其简洁、结构清晰而成为分类任务的经典示例。
本文将基于 scikit-learn
展示如何从 数据预处理 → 模型训练 → 测试集预测 → 可视化分析 → 模型持久化 全流程完成一次标准分类任务。
🌸 一、加载 IRIS 数据集
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X = iris.data # 特征(花萼长宽、花瓣长宽)
y = iris.target # 标签(类别)
print(f"特征维度: {X.shape}, 标签维度: {y.shape}")
IRIS 数据共包含 150 条记录,目标类别为 setosa
、versicolor
和 virginica
。
✂️ 二、划分训练集与测试集
from sklearn.model_selection import train_test_split
# 划分训练集与测试集(80%:20%)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"训练集样本数: {X_train.shape[0]}")
print(f"测试集样本数: {X_test.shape[0]}")
💡 使用
stratify=y
可以保持各类别在训练集与测试集中的比例一致,有助于提升模型稳定性。
⚙️ 三、数据标准化处理
from sklearn.preprocessing import StandardScaler
# 标准化:均值为0,方差为1
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
📌 KNN 属于基于距离的模型,对特征量纲敏感,因此建议对数据进行标准化处理。
🤖 四、训练 KNN 分类模型
from sklearn.neighbors import KNeighborsClassifier
# 初始化 KNN 模型,设定邻居数为3
model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train_scaled, y_train)
🔍 五、在测试集上进行预测与评估
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# 模型预测
y_pred = model.predict(X_test_scaled)
# 输出前10个预测结果
print("\n前10个样本预测结果:")
for i in range(10):
print(f"样本 {i+1}: 预测 = {iris.target_names[y_pred[i]]}, 实际 = {iris.target_names[y_test[i]]}")
模型评估:
# 准确率
acc = accuracy_score(y_test, y_pred)
print(f"\n测试集准确率: {acc:.4f}")
# 分类报告
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
混淆矩阵可视化:
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8, 6))
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('预测类别')
plt.ylabel('实际类别')
plt.title('混淆矩阵')
plt.tight_layout()
📊 六、可视化测试结果与特征重要性
import numpy as np
plt.figure(figsize=(15, 5))
# 子图1:测试集预测 vs 实际
plt.subplot(1, 2, 1)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='o', s=50, label='预测')
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap='coolwarm', marker='x', s=100, alpha=0.5, label='实际')
plt.xlabel("花萼长度")
plt.ylabel("花萼宽度")
plt.title("测试集预测 vs 实际")
plt.legend()
plt.grid(True)
# 子图2:特征重要性(方差衡量)
plt.subplot(1, 2, 2)
feature_importance = np.var(X_train, axis=0)
plt.bar(iris.feature_names, feature_importance)
plt.xlabel('特征')
plt.ylabel('方差')
plt.title('特征重要性(基于方差)')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
🧠 注意:KNN 并不显式提供特征重要性,这里用特征方差做近似分析,帮助我们理解哪些特征差异度更大。
💾 七、模型保存与加载(持久化)
import pickle
# 保存模型
with open('iris_knn_model.pkl', 'wb') as f:
pickle.dump(model, f)
# 加载模型
with open('iris_knn_model.pkl', 'rb') as f:
loaded_model = pickle.load(f)
# 使用加载模型预测新样本
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])
new_sample_scaled = scaler.transform(new_sample)
prediction = loaded_model.predict(new_sample_scaled)
print(f"\n示例样本预测类别: {iris.target_names[prediction][0]}")
📦
pickle
可用于模型的持久化和部署,是快速构建预测 API 的第一步。
总结
通过本文完整实战演示,我们完成了一个标准的机器学习分类任务:
- 使用 IRIS 数据集进行建模与评估
- 引入数据标准化提升 KNN 性能
- 利用 Seaborn 热力图和 Matplotlib 子图直观展示模型效果
- 添加特征重要性分析、模型保存与加载等实用功能
这不仅是对 sklearn 流程的复盘,也为后续迁移至更复杂任务(如 pipeline、调参、模型集成)打下基础。
🎉如果你觉得这篇文章对你有帮助,欢迎点赞 👍、收藏 ⭐ 和关注我!也欢迎评论区留言交流!