- 回调函数
- lambda函数
- hook函数的模块钩子和张量钩子
- Grad-CAM的示例
作业:理解下今天的代码即可
前置知识
lambda匿名函数
lambda函数是一种没有正式名称的函数,特点是:用极简的语法快速定义临时函数(用完即弃),一般只有一行。
定义的格式,lambda 参数列表:表达式
- 参数列表:单个参数、多个参数或无参数
- 表达式:函数的返回值,无需return,直接返回
调用匿名函数的方式与普通函数(def方法定义)相同:
f = lambda: "Hello, world!"
print(f()) # 输出: Hello, world!
x = lambda a, b : a * b
print(x(5, 6)) # 输出 11
回调函数(Callback Function)
回调函数:将一个函数作为参数传递给另一个函数,并在特定时机由后者“调用”前者。
使用回调函数的意义:
- 解耦逻辑:将通用逻辑与特定处理逻辑分离,使代码更模块化。
- 事件驱动编程:在异步操作、事件监听(如点击按钮、网络请求完成)等场景中广泛应用。
- 延迟执行:允许在未来某个时间点执行特定代码,而不必立即执行。
# 回调函数
def greet(name):
print(f"Hello, {name}!")
def process_user(callback, name): # callback为常用的参数名
print("Processing user...")
callback(name) # 调用回调函数
process_user(greet, "Alice") # greet函数作为参数传入process_user函数
# 装饰器
def greet(name):
print(f"Hello, {name}!")
# 定义装饰器
def with_callback(callback,name): # 装饰器工厂:返回装饰器的函数
def decorator(func): # 真正的装饰器
def wrapper(): # 包装后的函数,替换原函数
func() # 原函数操作
callback(name) # 新增操作,回调函数
return wrapper
return decorator # 返回装饰器
@with_callback(greet,'Alice') # 普通装饰器不能直接接受额外参数
def process_user(): # callback为常用的参数名
print("Processing user...")
process_user()
虽然回调函数类似于装饰器,也是函数里套函数,但是它们存在区别:
| 维度 | 回调函数(Callback) | 装饰器(Decorator) |
|---|---|---|
| 核心 | 被动响应:传递函数作为参数,等待触发 | 主动改造:用新函数包装原函数,修改行为 |
| 调用方式 | 由调用方显式传入,并在内部主动调用 | 通过 @ 语法在定义时绑定,自动替换原函数 |
| 调用时机 | 通常在运行时根据逻辑决定是否/何时调用 | 在模块加载或函数定义时完成包装(静态织入) |
| 语法形式 | 普通函数调用:func(callback=cb) | 专用语法糖:@decorator 写在函数定义上方 |
| 参数传递 | 回调的参数通常由调用方在触发时动态决定 | 装饰器可带参数(需工厂函数),但被装饰函数签名一般不变 |
| 典型场景 | 异步操作、事件处理、策略模式 | 日志记录、权限校验、缓存、计时、重试机制等 |
Hook函数
Hook函数是回调函数的一种结构化、规范化应用,能在特定时机插入自定义逻辑。它的核心思想是:系统或框架在执行流程中的某些“钩点”(hook point)主动调用预先注册的函数,从而实现行为扩展,而无需修改原有代码。
| 机制 | 控制权 | 典型用途 | 是否需显式传参 |
|---|---|---|---|
| 回调函数 | 调用方传入,被调方执行 | 异步结果处理、策略注入 | 是 |
| 装饰器 | 定义时静态绑定 | 统一增强函数(日志、权限等) | 否(通过语法糖) |
| Hook 函数 | 向框架注册,框架在预设点调用 | 插件系统、生命周期扩展 | 否(自动触发) |
pytorch中的Hook机制:
- 注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数。
- 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数。
- Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)。
在pytorch中主要有两种hook:模块 hook 和 张量hook。
Module Hooks
用于监听整个模块的输入和输出。根据位置,分为前向传播Hook和反向传播Hook。
Forward Hook
register_forward_hook:在模块的前向传播完成后立即被调用 ,实现模块输入和输出的访问。
流程说明:
- 定义模型
- 自定义钩子函数的逻辑
- 传入自定义函数,注册hook(注意位置)
- 前向传播,然后触发hook
- 清理hook(必选):防止Hook持续引用对象,导致内存泄漏等问题
注:为避免内存泄漏,使用detach()追踪梯度以及remove()移除hook。
# 创建一个列表用于存储中间层的输出
conv_outputs = []
# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):
"""
前向钩子函数,会在模块每次执行前向传播后被自动调用
参数:
module: 当前应用钩子的模块实例
input: 传递给该模块的输入张量元组
output: 该模块产生的输出张量
"""
print(f"钩子被调用!模块类型: {type(module)}")
print(f"输入形状: {input[0].shape}") # input是一个元组,对应 (image, label)
print(f"输出形状: {output.shape}")
# 保存卷积层的输出用于后续分析
# 使用detach()避免追踪梯度,防止内存泄漏
conv_outputs.append(output.detach())
# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)
# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)
output = model(x) # 执行前向传播 - 此时会自动触发钩子函数
# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()
通过Hook访问输入和输出后,可视化卷积层的输出:
# 可视化卷积层的输出
if conv_outputs:
plt.figure(figsize=(10, 5))
# 原始输入图像
plt.subplot(1, 3, 1)
plt.title('输入图像')
plt.imshow(x[0, 0].detach().numpy(), cmap='gray') # 显示灰度图像
# 第一个卷积核的输出
plt.subplot(1, 3, 2)
plt.title('卷积核1输出')
plt.imshow(conv_outputs[0][0, 0].detach().numpy(), cmap='gray')
# 第二个卷积核的输出
plt.subplot(1, 3, 3)
plt.title('卷积核2输出')
plt.imshow(conv_outputs[0][0, 1].detach().numpy(), cmap='gray')
plt.tight_layout()
plt.show()
Backward Hook
register_backward_hook:在反向传播过程中被调用,可以用来获取或修改梯度信息。
基本类似前向传播的hook:
# 定义一个存储梯度的列表
conv_gradients = []
# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):
# 模块:当前应用钩子的模块
# grad_input:模块输入的梯度
# grad_output:模块输出的梯度
print(f"反向钩子被调用!模块类型: {type(module)}")
print(f"输入梯度数量: {len(grad_input)}")
print(f"输出梯度数量: {len(grad_output)}")
# 保存梯度供后续分析
conv_gradients.append((grad_input, grad_output))
# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)
# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)
# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()
# 释放钩子
hook_handle.remove()
Tensor Hooks
- register_hook:监听张量的梯度。
- register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3
# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):
print(f"原始梯度: {grad}")
# 修改梯度,例如将梯度减半
return grad / 2
hook_handle = y.register_hook(tensor_hook) # 在y上注册钩子
z.backward() # 计算梯度
print(f"x的梯度: {x.grad}") # 最终梯度
# 释放钩子
hook_handle.remove()
反向传播流程:
z.backward()
↓
计算 ∂z/∂y = 48.0
↓
Hook被调用!tensor_hook(48.0)
↓ 返回修改后的梯度 24.0
计算 ∂z/∂x = 24.0 × ∂y/∂x
↓
∂y/∂x = 2x = 4.0
↓
最终梯度: 24.0 × 4.0 = 96.0
Grad-CAM可视化
Grad-CAM(Gradient-weighted Class Activation Mapping,梯度加权类激活映射)是一种可视化技术,用于理解CNN模型在做出预测时关注图像的哪些区域。
核心思想:通过反向传播得到的梯度信息,生成热力图,来衡量每个特征图对目标类别的重要性。

Grad-CAM的完整定义
# Grad-CAM实现
class GradCAM:
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer # 目标层
self.gradients = None # 存储梯度
self.activations = None # 存储激活值
# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度
self.register_hooks()
def register_hooks(self):
# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)
def forward_hook(module, input, output):
self.activations = output.detach() # 保存前向传播的输出(特征图)
# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0].detach() # 保存反向传播的梯度
# 在目标层注册前向钩子和反向钩子
self.target_layer.register_forward_hook(forward_hook)
self.target_layer.register_backward_hook(backward_hook)
def generate_cam(self, input_image, target_class=None):
# 前向传播,得到模型输出
model_output = self.model(input_image)
if target_class is None:
# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别
target_class = torch.argmax(model_output, dim=1).item()
# 清除模型梯度,避免之前的梯度影响
self.model.zero_grad()
# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度
one_hot = torch.zeros_like(model_output)
one_hot[0, target_class] = 1
model_output.backward(gradient=one_hot)
# 获取之前保存的目标层的梯度和激活值
gradients = self.gradients
activations = self.activations
# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性
weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果
cam = torch.sum(weights * activations, dim=1, keepdim=True)
# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响
cam = F.relu(cam)
# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围
cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)
cam = cam - cam.min()
cam = cam / cam.max() if cam.max() > 0 else cam
return cam.cpu().squeeze().numpy(), target_class
可以看到代码包含三个部分:
- 初始化:变量存储、hook注册(保证在卷积层创建之前)
- Hook定义:获取前向传播输出的activation和反向传播输出的gradients(后得到weights)
- 生成热图:计算加权映射(weight*activation)、激活、调大小与归一化等
对于生成热图部分的执行流程如下:
输入图像 (32×32)
↓
前向传播 → 模型预测: [狗:0.1, 猫:0.8, 鸟:0.1]
↓
选择目标类别: "猫" (类别1)
↓
构造one-hot: [0, 1, 0]
↓
反向传播 → 计算"猫"类别对特征图的梯度
↓
获取:
- 特征图: 512个7×7的激活图 (哪里被激活)
- 梯度: 512个7×7的梯度图 (多重要)
↓
计算通道权重: 对每个通道的梯度求平均
↓
生成CAM: 权重 × 特征图 → 求和 → 7×7的热力图
↓
后处理: ReLU → 上采样到32×32 → 归一化
↓
输出: 显示模型识别"猫"时关注的区域
可视化
叠加图像的处理

三个的图绘制(原始图像+热力图+叠加图):
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 选择一个随机图像
# idx = np.random.randint(len(testset))
idx = 102 # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")
# 转换图像以便可视化
def tensor_to_np(tensor):
img = tensor.cpu().numpy().transpose(1, 2, 0)
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
img = std * img + mean
img = np.clip(img, 0, 1) # 进行数值裁剪,保证所有值在[0,1]之间
return img
# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)
# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)
# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)
# 可视化
plt.figure(figsize=(12, 4))
# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')
# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')
# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')
plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
# print("Grad-CAM可视化完成。已保存为grad_cam_result.png")


被折叠的 条评论
为什么被折叠?



