根据训练结果绘制混淆矩阵（confusion matrix）以及训练/验证损失（training/validation loss）曲线的 Python 代码
混淆矩阵部分:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib
# Font settings so Chinese text (titles, axis labels, tick labels) renders.
# SimHei is usually only installed on Windows; listing a few common CJK fonts
# lets matplotlib find a usable font on macOS/Linux as well, instead of
# warning "findfont: ... not found" and drawing empty boxes. Matplotlib picks
# the first installed entry (newer versions also fall back glyph-by-glyph).
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['font.sans-serif'] = [
    'SimHei',               # Windows (preferred, tried first)
    'Microsoft YaHei',      # Windows
    'PingFang SC',          # macOS
    'Noto Sans CJK SC',     # common on Linux
    'WenQuanYi Micro Hei',  # common on Linux
    'DejaVu Sans',          # matplotlib's bundled default, last resort
]
matplotlib.rcParams['font.size'] = 14
# CJK fonts may lack the Unicode minus sign (U+2212); use the ASCII hyphen so
# negative numbers display correctly.
matplotlib.rcParams['axes.unicode_minus'] = False

# Path to the confusion-matrix CSV. The first column is used as the row
# labels (see set_index below); the remaining columns hold the matrix values.
data_path = "./confusion_matrix.csv"
confusion_matrix_data = pd.read_csv(data_path)

# Move the first column into the index so only the matrix values remain as
# data, with labels on both axes.
confusion_matrix = confusion_matrix_data.set_index(confusion_matrix_data.columns[0])
# Column labels of the matrix.
labels = confusion_matrix.columns
plt.figure(figsize=(10, 8