import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
import cv2
def load_labels(label_file):
    """Load the class-index -> class-name mapping stored in a ``.npy`` file.

    The file is expected to contain a single pickled object whose ``.item()``
    is a dict-like mapping (presumably written with ``np.save(path, some_dict)``
    — confirm against the producer).

    Returns:
        A ``(mapping, class_names)`` tuple where ``class_names`` is the list of
        the mapping's values in its iteration order.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # label files from a trusted source.
    stored = np.load(label_file, allow_pickle=True)
    mapping = stored.item()
    class_names = [name for name in mapping.values()]
    return mapping, class_names
def load_predictions(predictions_file):
    """Read the prediction table from ``predictions_file`` (CSV) into a DataFrame."""
    return pd.read_csv(predictions_file)
def generate_confusion_matrix(true_labels, predicted_labels, classes):
    """Compute a confusion matrix with rows/columns ordered as in ``classes``.

    Thin wrapper over ``sklearn.metrics.confusion_matrix``; per scikit-learn's
    documentation, rows index the true label and columns the predicted label.
    """
    return confusion_matrix(
        true_labels,
        predicted_labels,
        labels=classes,
    )
def plot_confusion_matrix(cm, classes, cmap=plt.cm.Blues, save_path='混淆矩阵.pdf'):
    """Draw ``cm`` as an annotated heat map, save it to ``save_path`` and show it.

    Parameters:
        cm: 2-D array-like confusion matrix. Row ``i`` is drawn against y tick
            ``classes[i]`` ("True") and column ``j`` against x tick
            ``classes[j]`` ("Prediction").
        classes: Tick labels, in the same order as the rows/columns of ``cm``.
        cmap: Colormap name or object passed to ``plt.imshow``.
        save_path: File the figure is written to with ``dpi=300``. Defaults to
            the previously hard-coded ``'混淆矩阵.pdf'``.
    """
    # Accept plain nested lists too; .max()/.shape below need an ndarray.
    cm = np.asarray(cm)
    tick_marks = np.arange(len(classes))

    plt.figure(figsize=(6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title('Confusion Matrix', fontsize=12)
    plt.xlabel('Prediction', fontsize=12, c='r')
    plt.ylabel('True', fontsize=12, c='r')
    plt.tick_params(labelsize=12)
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Cells above half the maximum count get white text, others black, so the
    # number stays readable against the cell colour (e.g. dark high end of 'Blues').
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # plt.text takes (x, y): column j on the x axis, row i on the y axis.
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
def find_misclassified_images(df, true_class, predicted_class):
    """Return the rows of ``df`` whose true label is ``true_class`` and whose
    top-1 prediction is ``predicted_class``.

    Reads the columns '标注类别名称' (annotated/true class name) and
    'top-1-预测名称' (top-1 predicted class name). The original index is kept.
    """
    is_true = df['标注类别名称'].eq(true_class)
    is_pred = df['top-1-预测名称'].eq(predicted_class)
    return df.loc[is_true & is_pred]
def visualize_misclassified_images(wrong_df):
    """Show each image listed in ``wrong_df``, titled "<true> Pred:<predicted>".

    Parameters:
        wrong_df: DataFrame with columns '图像路径' (image path),
            '标注类别名称' (annotated/true class name) and
            'top-1-预测名称' (top-1 predicted class name).

    Images that OpenCV cannot read are skipped with a ``UserWarning`` instead
    of failing inside ``cv2.cvtColor``.
    """
    import warnings

    for _, row in wrong_df.iterrows():
        img_path = row['图像路径']
        # NOTE(review): cv2.imread is known to fail on non-ASCII paths on some
        # platforms (e.g. Windows); confirm paths are readable there.
        img_bgr = cv2.imread(img_path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if img_bgr is None:
            warnings.warn(f'Could not read image, skipping: {img_path}')
            continue
        # OpenCV decodes to BGR channel order; matplotlib expects RGB.
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        plt.imshow(img_rgb)
        # f-string: identical output for str labels, and no TypeError for
        # non-str labels (which '+' concatenation would raise).
        plt.title(f"{row['标注类别名称']} Pred:{row['top-1-预测名称']}")
        plt.show()
if __name__ == '__main__':
    # Input artefacts: the saved index->label mapping and the test-set predictions CSV.
    label_file = 'idx_to_labels.npy'
    predictions_file = '测试集预测结果.csv'

    idx_to_labels, classes = load_labels(label_file)
    df = load_predictions(predictions_file)

    # Confusion matrix over every test sample, with classes in label-file order.
    confusion_matrix_model = generate_confusion_matrix(
        df['标注类别名称'],
        df['top-1-预测名称'],
        classes,
    )
    plot_confusion_matrix(confusion_matrix_model, classes, cmap='Blues')

    # Drill into one specific confusion: true 'daisy' predicted as 'dandelion'.
    true_class, predicted_class = 'daisy', 'dandelion'
    wrong_df = find_misclassified_images(df, true_class, predicted_class)
    print('误判:', wrong_df)
    visualize_misclassified_images(wrong_df)
# 制作混淆矩阵 (Making a confusion matrix)
# 于 2023-11-28 19:34:23 首次发布 (First published 2023-11-28 19:34:23)