CAM 的全称是 Class Activation Map(类激活图,也常被称作类别注意力图),它可以告诉我们网络是根据图片中的哪些区域得到相应类别得分的。下面是详细的代码:
import torch
from torchvision import models
import torch.nn as nn
import torchvision.transforms as tfs
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
class cal_cam(nn.Module):
def __init__(self, feature_layer: str = "layer4",
             weights_path: str = "G:/工作空间/预训练模型/resnet18-5c106cde.pth") -> None:
    """Build a ResNet-18, load its weights from disk and initialise CAM state.

    Args:
        feature_layer: Name of the layer whose gradients are required.
        weights_path: Location of the ResNet-18 state-dict file. The default
            is the path that was previously hard-coded here, so existing
            callers behave exactly as before.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` when
            ``weights_path`` does not exist.
    """
    super(cal_cam, self).__init__()
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Architecture only; the weights come from the local checkpoint below.
    self.model = models.resnet18(pretrained=False)
    # Load onto the CPU first so that a checkpoint saved from a GPU session
    # can still be read on a machine without CUDA; the model is moved to the
    # selected device right after.
    # NOTE(review): torch.load unpickles the file -- only point weights_path
    # at checkpoints you trust.
    state_dict = torch.load(weights_path, map_location="cpu")
    self.model.load_state_dict(state_dict)
    self.model.to(self.device)

    # Layer whose gradients are required.
    self.feature_layer = feature_layer
    # Recorded gradients.
    self.gradient = []
    # Recorded output feature maps.
    self.output = []
    # Per-channel mean / std; these values match the ImageNet statistics
    # commonly used with torchvision models (presumably consumed by the
    # image preprocessing elsewhere in this class -- confirm).
    self.means = [0.485, 0.456, 0.406]
    self.stds = [0.229, 0.224, 0.225]