多分类任务的混淆矩阵处理

多分类任务的混淆矩阵处理

在多分类任务中,不适合使用PR曲线和ROC曲线来进行指标评价,但我们仍可以通过混淆矩阵来进行处理。可以通过matplotlib的matshow()函数,直观地展示分类结果的好坏。

先使用cross_val_predict得出各个分类值的分数

 y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv= 3 )

再使用confusion_matrix()得出最终的混淆矩阵

conf_mx = confusion_matrix(y_train, y_train_pred)

然后使用 Matplotlib 的 matshow() 函数,将混淆矩阵以图像的方式呈现

plt.matshow(conf_mx, cmap=plt.cm.gray)

如下图所示,行代表了实际的类别,列代表了预测的结果,从图中可看出大致都在正对角线上,说明分类结果还不错。
在这里插入图片描述
但是我们应该关注仅包含误差数据的图像呈现,所以将混淆矩阵的每一个值除以相应类别的图片的总数目。这样子,你可以比较错误率,而不是绝对的错误数(这对大的类别不公平)

row_sums = conf_mx.sum(axis= 1 , keepdims= True )
norm_conf_mx = conf_mx / row_sums

然后用 0 来填充对角线(使正确的分类不可见),这样子就只保留了被错误分类的数据。

np.fill_diagonal(norm_conf_mx,  0 )
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)

如下图所示,8,9列比较亮,说明有很多都被错误地分到了8,9类中去。相似的,第 8、9 行也相当亮,也就是说8,9类也经常被误以为是其他类别。
在这里插入图片描述
所以通过这个混淆矩阵图像,分析混淆矩阵通常可以给你提供深刻的见解去改善你的分类器。回顾这幅图,看样子你应该努力改善分类器在类别8 和类别 9 上的表现,和纠正 3/5 的混淆。

举例子,你可以尝试去收集更多的数据,或者你可以构造新的、有助于分类器的特征。举例子,写一个算法去数闭合的环(比如,数字 8 有两个环,数字 6 有一个, 5 没有)。又或者你可以预处理图片(比如,使用 Scikit-Learn,Pillow, OpenCV)去构造一个模式,比如闭合的环。

### 绘制SVM多分类的混淆矩阵 为了绘制支持向量机(SVM)模型在多分类任务中的混淆矩阵,可以按照如下方法操作。首先加载必要的库并准备数据集: ```python from sklearn import svm from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay import matplotlib.pyplot as plt ``` 接着定义函数来训练SVM模型并对测试集做出预测: ```python def train_and_predict(X_train, y_train, X_test): clf = svm.SVC(kernel='linear', C=1, decision_function_shape='ovo') # 使用一对一策略处理多类问题 clf.fit(X_train, y_train) predictions = clf.predict(X_test) return predictions ``` 之后创建用于展示混淆矩阵的辅助函数: ```python def plot_confusion_mat(y_true, y_pred, class_names): cm = confusion_matrix(y_true, y_pred) disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names) fig, ax = plt.subplots(figsize=(8, 6)) disp.plot(ax=ax, cmap=plt.cm.Blues) plt.title('Confusion Matrix') plt.show() ``` 最后组合上述组件完成整个流程: ```python # 加载鸢尾花数据集作为例子 data = load_iris() X = data.data y = data.target class_names = data.target_names # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 训练模型并获取预测结果 predictions = train_and_predict(X_train, y_train, X_test) # 展示混淆矩阵 plot_confusion_mat(y_test, predictions, class_names) ``` 此过程展示了如何利用`sklearn`库中的工具构建一个多类别分类器,并通过可视化手段评估其性能表现。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值