from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
c=confusion_matrix(y_true,y_pred,labels=['1','2'])
sns.heatmap(c,cmp='Bluels',annot=True,fmt='.2f)
plt.show()
confusion_matrix()函数计算混淆矩阵,输入分别是真实标签值,预测标签值,labels可以自己设计,也可以默认。
注意:标签值是一维数组。
sns.heatmap()绘制热力图。cmp是颜色,有很多种颜色可以选择。annot=True表示在方框中显示数字,fmt可以设置数字的格式,保留小数点几位。还有ax可以设置文字属性等,这部分可以参考链接2
最后记得用plt.show()显示。
参考链接:
混淆矩阵-confusion_matrix(),注意链接第一个例子答案有问题