Python复现Grad-Cam算法,并保存类激活映射图。
前言
笔者假设看到这篇博文的同学是清楚Grad-Cam原理的。如果你对Grad-Cam不是很清楚,那么我建议你可以先看CAM对应的论文,然后再看Grad-Cam对应的论文,最后结合Grad-Cam论文中的公式来学习并应用下述代码。
退一步讲,如果你不想了解,只是单纯地想用Grad-Cam,那么直接复制下文代码,也是没有问题的,你几乎不需要做太大的改动!
一、根据网络输出的logits及特征图计算输入图像不同标签分量对应的Grad-Cam
建议大家将下述代码存为一个 ops.py文件 。其中,cal_grad_cam函数 根据网络输出的logits及中间特征图 计算输入图像对应 特征图的Grad-Cam。Grad-Cam可以理解为输入图像对应特征图的一个重要性矩阵,经过归一化后的Grad-Cam,其元素值表示了对应特征图位置元素的重要性。
以ResNet50为例,若cal_grad_cam函数的输入分别为网络输出的logits和最后一层卷积输出的特征图pools,那么将返回一个batch_size * 1 * 7 * 7大小的Grad-Cam,分