from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import torch
def plot_confusion_matrix(cm, savename, classes, title='Confusion Matrix', normalize=True):
    """Draw a confusion matrix as an annotated heat map, save it, then show it.

    Rows of ``cm`` go on the y axis ("Actual label") and columns on the x axis
    ("Predict label"). This is the layout produced by
    ``sklearn.metrics.confusion_matrix``: rows = true class, columns = predicted class.

    Args:
        cm: Square 2-D array-like; indexed as cm[row][col] for
            row, col in range(len(classes)).
        savename: Path the figure is written to with ``plt.savefig`` before
            it is shown. Pass a falsy value (``None`` / ``''``) to skip saving.
        classes: Tick labels for both axes, in the same order as cm's
            rows/columns.
        title: Figure title.
        normalize: If True, divide each row by its row total so every cell is
            the fraction of that actual class predicted as the column class.
            Rows whose total is 0 are shown as all zeros instead of NaN.
    """
    cm = np.asarray(cm)
    if normalize:
        # keepdims gives shape (n, 1), so each row is divided by its OWN total.
        # A plain (n,) sum would broadcast across rows and divide column j by
        # row j's total, which is wrong.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        fmt = "%0.2f"
    elif np.issubdtype(cm.dtype, np.integer):
        # Raw counts read better as integers than as "5.00".
        fmt = "%d"
    else:
        fmt = "%0.2f"

    plt.figure(figsize=(12, 8), dpi=100)
    n = len(classes)

    # Annotate each cell; (near-)zero cells are left blank to reduce clutter.
    for y_val in range(n):
        for x_val in range(n):
            value = cm[y_val][x_val]
            if value > 0.001:
                plt.text(x_val, y_val, fmt % (value,), color='black',
                         fontsize=10, va='center', ha='center')

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()

    positions = np.arange(n)
    plt.xticks(positions, classes, rotation=90)
    plt.yticks(positions, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')

    # Minor ticks sit on the cell boundaries (half-way between major ticks) so
    # the minor grid draws lines between cells rather than through them.
    ax = plt.gca()
    boundaries = positions + 0.5
    ax.set_xticks(boundaries, minor=True)
    ax.set_yticks(boundaries, minor=True)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)

    if savename:
        # Save before show(): after show() returns the figure may be closed.
        plt.savefig(savename, bbox_inches='tight')
    plt.show()
if __name__ == '__main__':
    from pathlib import Path

    # Build the path portably; the Windows-style "..\\data\\..." string only
    # resolves on Windows. The path is relative to the current working directory.
    path = Path('..') / 'data' / 'confusion_matrix.npy'

    # The file holds a dict stored as a 0-d object array, so it needs
    # allow_pickle=True and .item() to unwrap it.
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    contents = np.load(path, allow_pickle=True).item()
    y_true = contents["y_true"]
    y_pred = contents["y_pre"]

    # NOTE(review): assumes the data has exactly these 10 classes, in the
    # order confusion_matrix sorts the labels — confirm against the dataset.
    classes = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, 'confusion_matrix.png', classes, title='confusion matrix')