# 引入必要的库
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
import os
import torch
import torchvision
from PIL import Image
from numpy import interp
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from torch import nn
# Number GPUs by PCI bus ID (the order nvidia-smi uses) and expose only GPU 3 to
# this process. Both variables must be set before the first CUDA call.
# The variable name is CUDA_DEVICE_ORDER; the previous misspelling "ORDRE" was
# silently ignored by the CUDA runtime.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Test-time preprocessing: resize to 448x448, convert to a CHW float tensor, and
# normalize each channel with the standard ImageNet mean/std values.
data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((448, 448)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Instantiate the model and load its checkpoint (if present).
model_path = r"./net_params/model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): MainNet is not imported anywhere in this file; it must come from
# the project's model module — TODO add the import.
net = MainNet(3, mode="test")
if os.path.isfile(model_path):
    model_dict = net.state_dict()
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    pretrained_dict = torch.load(model_path, map_location=device)
    # A model saved while wrapped in nn.DataParallel prefixes every key with
    # "module.". Strip the prefix only when it is present, so that checkpoints
    # saved without the wrapper still match instead of being mangled.
    prefix = "module."
    pretrained_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_dict.items()
    }
    # Keep only entries the current model actually has.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    if not pretrained_dict:
        print("WARNING: no checkpoint key matched the model; weights were not loaded")
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    print("模型加载完成")
else:
    print("WARNING: checkpoint not found at {}; using the model's initial weights".format(model_path))
net = net.to(device)
# Wrap after loading, so the (possibly prefix-stripped) keys above match the bare model.
net = nn.DataParallel(net)
net.eval()
# Load the test list. Each non-blank line is expected to look like
# "<image path>---<integer class label>".
with open("./test.txt") as f:
    # Skip blank lines (e.g. a trailing empty line). Otherwise reading the label
    # field would raise IndexError.
    lines = [line for line in f if line.strip()]
path_list = [line.split("---")[0] for line in lines]
label_list = [int(line.split("---")[1]) for line in lines]
# One-hot encode the labels (one column per class) for one-vs-rest ROC analysis.
label_list = label_binarize(label_list, classes=[0, 1, 2, 3, 4])  # classes: list of class ids
# Number of classes = number of one-hot columns.
n_classes = label_list.shape[1]
# Run the network on every test image and collect one score vector per image.
final_pre = []
# Inference only: disable autograd so no computation graph is built per image.
with torch.no_grad():
    for image_path in path_list:
        # The context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
        # Preprocess, then add a leading batch dimension of size 1.
        img_data = data_transforms(img).unsqueeze(0)
        pre_out = net(img_data.to(device))
        # NOTE(review): MainNet's output structure is not visible here; pre_out[0][0]
        # is assumed to be the class-score vector of this single image — confirm.
        pre = pre_out[0][0]
        final_pre.append(pre.detach().cpu().numpy())
# Stack into an array with one row per image (rows align with label_list).
final_pre = np.array(final_pre)
# Per-class (one-vs-rest) ROC curve and AUC, keyed by integer class index.
# NOTE: if a class has no positive sample in the test set, roc_curve warns and its
# TPR is undefined (nan), which also propagates into the macro average.
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(label_list[:, i], final_pre[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
# Micro-average ROC (method 2): flatten the one-hot label matrix and the score
# matrix so every (sample, class) pair counts as one binary decision, then
# compute a single ROC curve and its area.
fpr["micro"], tpr["micro"], _ = roc_curve(
    label_list.reshape(-1),
    final_pre.reshape(-1),
)
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Macro-average ROC (method 1): evaluate every per-class curve on the union of all
# FPR points, average the interpolated TPRs, and take the area of the mean curve.
# First aggregate all false positive rates.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
# Then interpolate each class's ROC curve at these points and sum the TPRs.
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])
# Finally average over classes and compute the AUC.
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
# Plot the micro/macro averages, every per-class curve, and the chance diagonal,
# then save the figure to ./ROC.jpg and display it.
lw = 2
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.3f})'.format(roc_auc["micro"]),
         color='black', linestyle='--', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.3f})'.format(roc_auc["macro"]),
         color='lightcoral', linestyle='--', linewidth=4)
# cycle() repeats these colors if there are more than five classes. Extend the
# list (e.g. "cyan", "blue", "magenta") to keep the curves distinguishable.
colors = cycle(['red', 'saddlebrown', 'darkorange', "olive", "palegreen"])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, linewidth=lw,
             label='ROC curve of class {0} (area = {1:0.3f})'.format(i, roc_auc[i]))
# Chance-level (random classifier) diagonal.
plt.plot([0, 1], [0, 1], 'b--', linewidth=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('roc')
plt.legend(loc="lower right")  # legend in the lower-right corner
# Save before show(): some backends clear the figure once the window is closed.
plt.savefig("./ROC.jpg")
plt.show()
# Reference: https://blog.youkuaiyun.com/u011047955/article/details/87259052