Attention Visualization for Scene Classification

This post introduces Grad-CAM, a technique for visualizing the attention of a deep learning model. By weighting the feature maps with channel-wise averaged gradients, Grad-CAM produces a heatmap showing which image regions the model focuses on when making a prediction. This helps in understanding the model's decision process, spotting potential problems, and identifying directions for improvement.


[demo image]

README

Visualizing the network's attention makes it easy to see what the network gets wrong and where there is room for improvement.
Reference: https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82

The idea behind attention visualization (formalized in the equations after this list):

  1. Grab the gradient of the feature layer via a hook; shape (batch, 2048, 8, 8)
  2. Average each channel's gradient over the spatial dimensions; shape (batch, 2048)
  3. Take the feature maps computed by the forward pass; shape (batch, 2048, 8, 8)
  4. Weight the feature maps from step 3 by the averaged gradients from step 2; shape (batch, 2048, 8, 8)
  5. Average the weighted (batch, 2048, 8, 8) tensor over the channel dimension to get (batch, 8, 8), then scale it to [0, 1]; that is the heatmap
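
In the notation of the Grad-CAM paper (expressions (1) and (2) in https://arxiv.org/pdf/1610.02391.pdf), steps 2, 4, and 5 compute

$$\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)$$

where $A^k$ is the $k$-th channel of the last convolutional feature map, $y^c$ is the score for class $c$, and $Z$ is the number of spatial positions (here $8 \times 8 = 64$). The ReLU is applied further down in the code, right before normalization.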

Source code

https://github.com/CarryHJR/remote-sense-quickstart
The model part carries over from scene-classification-quickstart.ipynb.

Define the module used to compute the convolutional features

net_single = list(net.children())[0]

import torch
import torch.nn as nn

class ResNetCam(nn.Module):
    def __init__(self):
        super(ResNetCam, self).__init__()

        # reuse the ResNet trained in the previous notebook
        self.resnet = net_single

        # dissect the network to access its last convolutional layer (through layer4)
        self.features_conv = nn.Sequential(*list(self.resnet.children())[:8])

        # global average pooling over the spatial dimensions
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # reuse the ResNet's fully connected classifier
        self.fc = self.resnet.fc

        # placeholder for the gradients
        self.gradients = None

    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad

    def forward(self, x):
        x = self.features_conv(x)
        print(x.shape)  # (batch, 2048, 8, 8) for 256x256 inputs

        # register the hook on the last conv feature map
        h = x.register_hook(self.activations_hook)

        # apply the remaining pooling and the classifier
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten; works for any batch size
        x = self.fc(x)
        return x

    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients

    # method for the activation extraction
    def get_activations(self, x):
        return self.features_conv(x)

resNetCam = ResNetCam()
_ = resNetCam.eval()
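
The 8×8 spatial size comes from ResNet's overall stride of 32, so it assumes 256×256 inputs. A quick sanity check with a dummy tensor (a minimal sketch, assuming the model sits on the GPU as in the notebook):

```python
# confirm the feature-map shape the steps above rely on
dummy = torch.randn(1, 3, 256, 256).cuda()
with torch.no_grad():
    feats = resNetCam.get_activations(dummy)
print(feats.shape)  # expected: torch.Size([1, 2048, 8, 8])
```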

Attention visualization

img, _ = next(dataiter_val)

# forward pass; get the most likely prediction of the model
pred = resNetCam(img.cuda())

# index of the predicted class
index = pred.argmax(dim=1).item()

# backprop the score of the predicted class to trigger the hook
pred[:, index].backward()

# step 1: grab the gradient of the feature layer via the hook; shape (batch, 2048, 8, 8)
gradients = resNetCam.get_activations_gradient()

# step 2: average the gradient over the batch and spatial dimensions; shape (2048,)
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# step 3: feature maps from the forward pass; shape (batch, 2048, 8, 8)
activations = resNetCam.get_activations(img.cuda()).detach()

# step 4: weight each channel of the feature maps by its averaged gradient; shape (batch, 2048, 8, 8)
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]

# step 5: average the weighted feature maps over the channel dimension; shape (8, 8) after squeeze
heatmap = torch.mean(activations, dim=1).squeeze()

# ReLU on top of the heatmap
# expression (2) in https://arxiv.org/pdf/1610.02391.pdf
heatmap = heatmap.cpu()
heatmap = torch.clamp(heatmap, min=0)  # stay in torch so torch.max below works

# normalize the heatmap to [0, 1]
heatmap /= torch.max(heatmap)
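
The per-channel loop in step 4 can be collapsed into a single broadcasted multiply; a minimal equivalent sketch, starting from a fresh copy of the unweighted activations:

```python
# vectorized steps 4-5: broadcast the (2048,) weights over (batch, 2048, 8, 8)
acts = resNetCam.get_activations(img.cuda()).detach()
weighted = acts * pooled_gradients.view(1, -1, 1, 1)
heatmap_vec = weighted.mean(dim=1).squeeze().cpu()
heatmap_vec = torch.clamp(heatmap_vec, min=0)  # ReLU, expression (2) of the paper
heatmap_vec /= heatmap_vec.max()               # normalize to [0, 1]
```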



# undo the input normalization to recover a displayable image
# (mean/std are the dataset stats from scene-classification-quickstart.ipynb)
image = img[0].numpy().transpose((1, 2, 0))
image = np.array(std) * image + np.array(mean)
image = np.clip(image, 0, 1)
image = np.uint8(255 * image)

import cv2
heatmap_resize = cv2.resize(heatmap.numpy(), (image.shape[1], image.shape[0]))
heatmap_resize = np.uint8(255 * heatmap_resize)
heatmap_resize = cv2.applyColorMap(heatmap_resize, cv2.COLORMAP_JET)
# applyColorMap returns BGR; convert to RGB before blending with the RGB image
heatmap_resize = cv2.cvtColor(heatmap_resize, cv2.COLOR_BGR2RGB)
superimposed_img = heatmap_resize * 0.4 + image
superimposed_img = superimposed_img / np.max(superimposed_img)
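
To also keep a copy of the overlay on disk (a one-line sketch with a hypothetical filename; cv2.imwrite expects uint8 BGR):

```python
# convert the [0, 1] RGB overlay back to uint8 BGR before writing
cv2.imwrite('gradcam_overlay.jpg',
            cv2.cvtColor(np.uint8(255 * superimposed_img), cv2.COLOR_RGB2BGR))
```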

fig, axes = plt.subplots(1,3,figsize=(12,4))
axes[0].imshow(superimposed_img)

tmp = heatmap.squeeze().numpy()
im = axes[1].imshow(tmp, interpolation='nearest')
axes[1].figure.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

axes[2].imshow(image)

print(idx_to_class[index])

[demo image]
