ROC曲线

这段代码实现了加载预训练模型,对测试数据进行处理,预测并计算每一类的ROC曲线和AUC值。首先,定义了数据预处理步骤,然后实例化模型并加载权重。接着,对每个测试样本进行预测,将预测结果存储。最后,计算并绘制了微观和宏观平均的ROC曲线,保存为ROC.jpg。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# 引入必要的库
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
import os
import torch
import torchvision
from PIL import Image
from numpy import interp
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

from torch import nn
os.environ['CUDA_DEVICE_ORDRE'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((448, 448)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# 模型实例化
model_path = r"./net_params/model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


net = MainNet(3, mode="test")

if os.path.isfile(model_path):
    model_dict = net.state_dict()

    pretrained_dict = torch.load(model_path)

    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}

    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    print("模型加载完成")
net = net.to(device)
net = nn.DataParallel(net)


net.eval()
# 加载数据
lines = open("./test.txt").readlines()
path_list = [line.split("---")[0] for line in lines]
label_list = [int(line.split("---")[1]) for line in lines]

# 将标签二值化
label_list = label_binarize(label_list, classes=[0,1,2,3,4])#classes:列别列表

# 设置种类
n_classes = label_list.shape[1]
# 处理输入图片并收集预测值
final_pre = []
for i, image_path in enumerate(path_list):
    img = Image.open(image_path).convert("RGB")

    img_data = data_transforms(img)
    img_data = img_data.unsqueeze(0)

    pre_out = net(img_data.to(device))

    pre =pre_out[0][0]

    pre =pre.detach().cpu().numpy()
    final_pre.append(pre)
final_pre = np.array(final_pre)


# 计算每一类的ROC
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(label_list[:, i], final_pre[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area(方法二)
fpr["micro"], tpr["micro"], _ = roc_curve(label_list.ravel(), final_pre.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area(方法一)
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
lw=2
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.3f})'
               ''.format(roc_auc["micro"]),
         color='black', linestyle='--', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.3f})'
               ''.format(roc_auc["macro"]),
         color='lightcoral', linestyle='--', linewidth=4)

colors = cycle(['red', 'saddlebrown', 'darkorange',"olive","palegreen"])
#,"cyan","blue","magenta"])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, linewidth=lw,
             label='ROC curve of class {0} (area = {1:0.3f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'b--', linewidth=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')

plt.title('roc')
plt.legend(loc="lower right")#图例位置,右下角

plt.savefig("./ROC.jpg")
plt.show()

请参考:https://blog.youkuaiyun.com/u011047955/article/details/87259052

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值