PyTorch学习记录四——pytorch中的hook方法

Hook函数

[四种钩子方法]https://gitcode.youkuaiyun.com/65ed75221a836825ed799ce7.html?dp_token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6NzczNjcwLCJleHAiOjE3NDA1NjkzNDIsImlhdCI6MTczOTk2NDU0MiwidXNlcm5hbWUiOiJtbW1tZmgifQ.Twr_zsCNcfWd36lAOa5ZT9jNUWGuNaMgBWu9wWfqfr0
[英文GiantPandaCV]https://www.bilibili.com/video/BV1MV411t7td?spm_id_from=333.788.videopod.sections&vd_source=3fa24499c8737a243c0012a2cf75cb9a

  • 在PyTorch中,hook函数是一种强大的工具,允许用户在不修改模型代码的情况下,介入模型的前向传播和反向传播过程。在模型调试、特征提取、梯度分析等任务中非常有用。
  • 比如为了节省显存(内存),pytorch在计算过程中不保存中间变量,包括中间层的特征图和非叶子张量的梯度等。有时对网络进行分析时需要查看或修改这些中间变量,此时就需要注册一个钩子(hook)来导出需要的中间变量。
  • 比如做知识蒸馏时,可能就需要用到钩子函数将部分关键特征层提取出来,使得教师模型和学生模型的对应层进行学习。

四种hook方法

torch.Tensor.register_hook()  # tensor级别
torch.nn.Module.register_forward_hook()  # Module级别
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook()

forward hook——register_forward_hook()

forward是在模块的前向传播结束后自动调用的函数。他接受三个参数。

  1. module:当前模块
  2. input:模块的输入
  3. output:模块的输出

从而可以获得中间层的输出、修改输出等。
以下代码展示如何通过forward hook提取卷积层的输出特征:

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
	def __init__(self):
	supper(SimpleCNN, self).__init__()
	self.conv1 = nn.Cnov2d(3, 16, kernel_size=3)
	self.relu = nn.ReLU()

	def forward(self, x):
	x = self.conv1(x)
	x = self.relu(x)
	return x

def extract_features(module, input, output):
	features = output.detach().cpu().numpy()
	print("Extracted features shape:", features.shape)

model = SimpleCNN()
handle = model.conv1.register_forward_hook(extract_features)
input_tensor = torch.randn(1, 3, 32, 32)
output = model(input_tensor)
handle.remove()  # 移除 hook

在这个例子中,extract_features函数会在conv1层的前向传播结束后被调用,并将该层的输出特征保存。
PS:forward hook的输入是一个元组,输出是一个张量或元组。

Backward hook——register_hook()

backward hook是在模块的反向传播结束后自动调用的函数。他接受一个参数:grad_output(模块或张量的梯度输出)。我们可以通过他获取或修改梯度。
以下代码展示了如何通过backward hook 获取中间变量的梯度:

import torch

def grad_hook(grad):
	y_grad.append(grad)

y_grad = list()
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x + 1
y.register_hook(grad_hook)
z = torch.mean(y * y)
z.backward()
print("y.grad:", y.grad)  # y不是叶子节点,其grad为none
print("y_grad[0]:", y_grad[0])  # 通过hook获取 y 的梯度

在这个例子中,y不是叶子节点,其梯度在反向传播结束后会被释放,但通过backward hook,我们仍然可以获取其梯度。
PS:backward hook的输入是梯度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值