一 tf.confusion_matrix(y,pred_y,num_classes)
以二分类为例,混淆矩阵为2*2的矩阵,如果我们真实标签为real=[0,1,1,0,1],预测标签为predict=[0,1,0,1,1]
num_classes为分类数,一定要设置这个!!!否则默认为None
测试
import numpy as np
import tensorflow as tf
y=np.array([[1,0],[1,0],[1,0],[1,0],[1,0]])
y=tf.convert_to_tensor(y)
predict=np.array([[1,0],[1,0],[1,0],[1,0],[1,0]])
predict=tf.convert_to_tensor(predict)
confusion_matrix=tf.confusion_matrix(tf.argmax(y,1),tf.argmax(predict,1),num_classes=2)
with tf.Session() as sess: #开始一个会话
matrix=sess.run(confusion_matrix)
print(matrix)
#输出[[5 0] [0 0]]
#如果不加num_classes=2就会输出[[5]]
二 评估指标
得到混淆矩阵后,即可得到TP,TN,FP,FN。根据各种指标的公式可求
def evaluate(confusion_metrics):
TP=confusion_metrics[0][0]
FP=confusion_metrics[0][1]
FN=confusion_metrics[1][0]
TN=confusion_metrics[1][1]
ACC=(TP+TN)/(TP+TN+FP+FN)
SEN=TP/(TP+