逻辑回归在鸢尾花数据集上的二分类实践
逻辑回归是机器学习中最经典的分类算法之一,广泛应用于二分类问题。在本文中,我们将通过一个具体的案例——基于鸢尾花(Iris)数据集的二分类任务,来深入探讨逻辑回归的实现和应用。我们将使用 Python 和 scikit-learn
库来完成这一任务,并通过可视化手段展示模型的性能。
一、鸢尾花数据集简介
鸢尾花数据集是机器学习领域中最著名的数据集之一,由 R.A. Fisher 在 1936 年首次引入。该数据集包含 150 个样本,每个样本有 4 个特征,分别表示鸢尾花的 花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)和花瓣宽度(petal width)。样本分为三个类别,每个类别有 50 个样本,分别对应三种鸢尾花:Setosa、Versicolour 和 Virginica。
在本次实验中,我们将专注于一个二分类问题:区分 Virginica 类和非 Virginica 类的鸢尾花。为此,我们将数据集中的类别 2(Virginica)标记为正类(1),其余类别标记为负类(0)。
数据集特征信息如下:
============== ==== ==== ======= ===== ==================== Min Max Mean SD Class Correlation ============== ==== ==== ======= ===== ==================== sepal length: 4.3 7.9 5.84 0.83 0.7826 sepal width: 2.0 4.4 3.05 0.43 -0.4194 petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) ============== ==== ==== ======= ===== ====================
二、数据预处理
首先,我们加载鸢尾花数据集,并提取花瓣宽度(petal width)作为特征,因为根据数据集的描述,花瓣宽度与类别标签的关联性最强,并对y进行二分化,处理后如下表所示:
Python复制
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris['data'][:, 3:] # 只取花瓣宽度作为特征
y = (iris['target'] == 2).astype(np.int32) # 将类别 2 标记为 1,其余标记为 0
三、逻辑回归模型的训练
接下来,我们使用 scikit-learn
的 LogisticRegression
类来训练逻辑回归模型。为了确保模型能够收敛,我们选择使用随机平均梯度下降(solver='sag'
)作为优化算法,并设置最大迭代次数为 1000。
Python复制
# 训练逻辑回归模型
binary_classifier = LogisticRegression(solver='sag', max_iter=1000)
binary_classifier.fit(X, y)
四、模型预测与可视化
为了评估模型的性能,我们生成一个测试集,并使用模型进行预测。我们将预测结果的概率值和类别标签进行可视化,以直观地展示模型的分类效果。
Python复制
# 生成测试集
X_new = np.linspace(0, 3, 1000).reshape(-1, 1) # 在 0 到 3 之间均匀取 1000 个点
# 预测测试集的概率值和类别标签
y_proba = binary_classifier.predict_proba(X_new) # 预测正例和负例的概率
y_hat = binary_classifier.predict(X_new) # 预测类别标签
# 可视化预测结果
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.plot(X_new, y_proba[:, 1], label='Probability of Virginica (Class 1)', color='blue')
plt.plot(X_new, y_hat, label='Predicted Class', color='red', linestyle='--')
plt.scatter(X[y == 0], np.zeros(len(X[y == 0])), color='green', label='Non-Virginica (Class 0)')
plt.scatter(X[y == 1], np.ones(len(X[y == 1])), color='purple', label='Virginica (Class 1)')
plt.xlabel('Petal Width (cm)')
plt.ylabel('Probability / Class Label')
plt.title('Logistic Regression on Iris Dataset')
plt.legend()
plt.grid(True)
plt.show()
可视化结果
运行上述代码后,你可看到相应可视化结果:
图像说明:
-
蓝色曲线 表示模型预测为 Virginica 类(正类)的概率值。
-
红色虚线 表示模型的预测类别标签(0 或 1)。
-
绿色点 表示非 Virginica 类的样本(负类)。
-
紫色点 表示 Virginica 类的样本(正类)。
可以看到,随着花瓣宽度的增加,模型预测为 Virginica 类的概率逐渐增加,并在某个阈值处将类别标签从 0 转变为 1。这表明花瓣宽度是一个非常有效的特征,能够很好地区分 Virginica 类与其他类别的鸢尾花。
五、总结
通过本次实验,我们成功地使用逻辑回归模型对鸢尾花数据集进行了二分类任务。我们选择了花瓣宽度作为特征,并通过可视化手段展示了模型的分类效果。逻辑回归模型在处理二分类问题时表现出色,能够提供直观的概率预测和清晰的决策边界。
此外,我们还了解到数据预处理和特征选择的重要性。在实际应用中,选择合适的特征可以显著提高模型的性能。希望本文能够帮助你更好地理解逻辑回归及其在实际问题中的应用。如果你对逻辑回归的其他方面感兴趣,比如正则化、多分类问题等,欢迎继续探索和学习!