一、核心函数
第一行的参数是必须的,第二行的参数是save_image这个函数调用make_grid时用的,make_grid的作用是将多通道的特征拼接成一个网格状排列的大图,make_grid返回的是一个tensor,save_image直接将图片保存在filename路径。
# tensor 的shape要变成[C, 1, H, W]
torchvision.utils.save_image(tensor, filename,
nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
简单的例子
a = torch.randn((1, 32, 256, 512))
file_path = 'result/a.png'
a = a.permute(1, 0, 2, 3)
torchvision.utils.save_image(a, file_path)
二、优雅的使用方式
上述的简单例子可以适用于任意情况,只要写在模型的forward()函数内即可,如果不想写在forward()内,则可以用另外的方式实现,结合torch.nn.Module.register_forward_hook函数。
1. torch.nn.Module.register_forward_hook
函数声明:register_forward_pre_hook(hook: Callable[..., None])
括号中的参数是一个函数名,姑且称之为hook_func,hook_func需要自己实现。
hook_func可从前向过程中接收到三个参数:hook_func(module, input, output)。其中module指的是模块的名称,调用register_forward_hook后会自己绑定,比如对于卷积层,module是Conv2d,注意module带有具体的参数。input和output就是我们所需的特征图,这二者分别是module的输入和输出,输入可能有多个(比如concate层就有多个输入),输出只有一个,所以input是一个tuple,其中每一个元素都是一个tensor,而输出就是一个tensor。一般而言output可能更经常拿来做分析。我们可以在hook_func中将特征图画出来并保存为图片,所以hook_func是实现可视化的关键。
2. 具体使用方法
# 定义全局字典存储features
features = {}
def get_activation(name):
"""
:param name: 特征名字
:return:
"""
def hook_func(module, input, output):
"""
:param module: 调用register_forward_hook后自动绑定
:param input: 输入
:param output: 输出
:return:
"""
# 这里可以写自己所需要的处理方式
features[name] = output.detach()
return hook_func
# 可视化features
def show_features(features):
"""
:param features: 特征字典
:return:
"""
for keys, values in features.items():
# 将特征的shape变成[C, 1, H, W]
values = values.permute(1, 0, 2, 3)
# 保存路径
image_name = 'result/' + keys + '.png'
torchvision.utils.save_image(values, image_name)
if __name__ == '__main__':
# 创建模型
model = nn.DataParallel(.......)
# 为模型的某一层或者某一模块注册hook
model.module.my_block_or_layer.register_forward_hook(get_activation('my_block_or_layer'))
# 调用模型前向传播
.................
# 保存特征图
show_features(features)
特征图可视化结果
