# 导入工具包
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchcam.methods import LayerCAM
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM
from torchvision import transforms
from torchcam.utils import overlay_mask
import numpy as np
# windows操作系统
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
def torchcam_single_ours3(image_path,checkpoint_path,
idx_to_labels_path,labels_to_idx_path,
show_class):
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)
# 导入ImageNet预训练模型
model=torch.load(checkpoint_path)
model = model.eval().to(device)
# 导入可解释性分析方法
cam_extractor = LayerCAM(model)
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
idx_to_labels = np.load(idx_to_labels_path, allow_pickle=True).item()
labels_to_idx = np.load(labels_to_idx_path, allow_pickle=True).item()
# 前向预测
img_pil = Image.open(image_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
pred_logits = model(input_tensor)
pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()
if show_class:
class_id = labels_to_idx[show_class]
show_id = class_id
else:
show_id = pred_id
# 生成类激活图(CAM)
activation_map = cam_extractor(show_id, pred_logits) # 使用 cam_extractor 函数生成类激活图,传入预测的类别索引和预测 logits
# 处理激活图结果
activation_map = activation_map[0][0].detach().cpu().numpy() # 从张量中提取类激活图的 NumPy 数组表示
result = overlay_mask(img_pil, Image.fromarray(activation_map),
alpha=0.4) # 用于将类激活图(Activation Map)叠加到原始图像上,以可视化模型对输入图像的关注区域。
# 可视化
plt.subplots(figsize=(6, 6))
plt.imshow(result)
plt.axis('off')
plt.title('Pred:{} Show:{}'.format(idx_to_labels[pred_id], show_class))
plt.show()
if __name__ == '__main__':
image_path='test_img/test_fruits.jpg'
checkpoint_path='checkpoint/fruit30_pytorch_20220814.pth'
idx_to_labels_path='idx_to_labels.npy'
labels_to_idx_path='labels_to_idx.npy'
# 可视化热力图的类别,如果不指定,则为置信度最高的预测类别
show_class = '火龙果'
torchcam_single_ours3(image_path,checkpoint_path,
idx_to_labels_path,labels_to_idx_path,
show_class)
torchcam-single-ours3
于 2023-11-29 22:22:01 首次发布