torchcam-single-ours3

# 导入工具包
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchcam.methods import LayerCAM
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM
from torchvision import transforms
from torchcam.utils import overlay_mask
import numpy as np
# windows操作系统
plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号

def torchcam_single_ours3(image_path,checkpoint_path,
                          idx_to_labels_path,labels_to_idx_path,
                          show_class):
    # 有 GPU 就用 GPU,没有就用 CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device', device)
    # 导入ImageNet预训练模型
    model=torch.load(checkpoint_path)
    model = model.eval().to(device)
    # 导入可解释性分析方法
    cam_extractor = LayerCAM(model)
    # 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
    test_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize(
                                             mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])])

    idx_to_labels = np.load(idx_to_labels_path, allow_pickle=True).item()
    labels_to_idx = np.load(labels_to_idx_path, allow_pickle=True).item()

    # 前向预测
    img_pil = Image.open(image_path)
    input_tensor = test_transform(img_pil).unsqueeze(0).to(device)  # 预处理
    pred_logits = model(input_tensor)
    pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()

    if show_class:
        class_id = labels_to_idx[show_class]
        show_id = class_id
    else:
        show_id = pred_id
    # 生成类激活图(CAM)
    activation_map = cam_extractor(show_id, pred_logits)  # 使用 cam_extractor 函数生成类激活图,传入预测的类别索引和预测 logits
    # 处理激活图结果
    activation_map = activation_map[0][0].detach().cpu().numpy()  # 从张量中提取类激活图的 NumPy 数组表示
    result = overlay_mask(img_pil, Image.fromarray(activation_map),
                          alpha=0.4)  # 用于将类激活图(Activation Map)叠加到原始图像上,以可视化模型对输入图像的关注区域。
    # 可视化
    plt.subplots(figsize=(6, 6))
    plt.imshow(result)
    plt.axis('off')
    plt.title('Pred:{} Show:{}'.format(idx_to_labels[pred_id], show_class))
    plt.show()


if __name__ == '__main__':
    image_path='test_img/test_fruits.jpg'
    checkpoint_path='checkpoint/fruit30_pytorch_20220814.pth'
    idx_to_labels_path='idx_to_labels.npy'
    labels_to_idx_path='labels_to_idx.npy'
    # 可视化热力图的类别,如果不指定,则为置信度最高的预测类别
    show_class = '火龙果'
    torchcam_single_ours3(image_path,checkpoint_path,
                          idx_to_labels_path,labels_to_idx_path,
                          show_class)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Make_magic

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值