一、直方图可视化数据分布
1. 知识介绍
在PyTorch 模型的每一层注册一个 forward hook,从而能够捕获每层的输出
简单列表存储形式(只能顺序查看每层输出,下文会有改进版用字典将层名字和层输出值对应)
activations = []
def hook_fn(module, input, output):
activations.append(output.detach()) # 记录输出并去除计算图
# 注册 hook
for layer in model.children():
layer.register_forward_hook(hook_fn)
#列表存储版本 可视化
for i, activation in enumerate(activations):
plt.figure(figsize=(6, 4))
plt.hist(activation.numpy().flatten(), bins=100)
plt.title(f'Layer {i + 1} Activation Distribution')
plt.xlabel('Activation Value')
plt.ylabel('Frequency')
plt.show()
1. 什么是 Hook?
在 PyTorch 中,hook 是一种机制,允许你在模型的前向传播(
forward
)或反向传播(backward
)过程中插入自定义操作。具体来说,forward hook 允许你在每一层的输出被计算之后,执行一些额外的操作,比如记录激活值、修改输出等。2.model.children() 的作用
model.children() 是一个 PyTorch 中的迭代器,用来返回模型中的每个子模块(子层)。对于一个 nn.Module(即神经网络模型),它的子模块可以是层(例如 nn.Linear,nn.Conv2d)或者其他子网络(例如子模型)。model.children()返回的是这些子模块的迭代器,它可以让你遍历模型中的每一层。
3. register_forward_hook的作用
register_forward_hook是一个方法,它可以在你对模型进行前向传播时,将一个 hook 函数(
hook_fn
)注册到每一层。这个函数会在每次经过该层时被调用,从而允许你对输出数据进行进一步的处理、修改或记录。
2. 代码实战
2.1 关键步骤
1. 捕获激活值:
- 使用 forward_hook 在前向传播时捕获每一层的激活值,并存储在字典 activations 中,键为层的名字,值为层的输出。
def hook_fn(module, input, output):
for name, layer in model.named_children():
if layer == module: # 如果当前层匹配
activations[name] = output.detach() # 将输出存储到字典中
- 注册 forward_hook:使用 register_forward_hook 方法