ONNX: Notes on problems encountered when converting an EfficientNet network to ONNX format

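This post records the problems I ran into while converting a PyTorch model to ONNX format, along with their solutions. They include running out of GPU memory, incorrect use of DataParallel, and ONNX attributes that fail to export.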


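Environment: Ubuntu 16.04, Python 3.5, torch 1.0.1, GTX 1080, torchvision 0.2.2
I received a trained torch model from a colleague and wanted to save it in ONNX format, in preparation for a later conversion to TensorRT.
The code is only a few lines, using the export method defined in torch.onnx: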

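import torch
from collections import OrderedDict
from efficientnet import efficientnet_b0b

model_path = '0net_params.pkl'
filepath = 'efficientnet.trt'  # the file written here is ONNX, despite the extension
num_classes = 1852
net = efficientnet_b0b((224, 224), num_classes=num_classes)

# Use the GPU when there is one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

# The checkpoint was saved from an nn.DataParallel model, so every key
# starts with 'module.'; drop that prefix (7 characters) before loading.
state_dict = torch.load(model_path, map_location=device)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
net.load_state_dict(new_state_dict)
net.eval()
print(net)

# Variable is no longer needed; a plain tensor works as the dummy input.
dummy_input = torch.randn(64, 3, 224, 224, device=device)
with torch.no_grad():
    torch.onnx.export(net, dummy_input, filepath, verbose=True)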

The problems I ran into are recorded below:

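Problem 1: out of memory

At first, before I added "with torch.no_grad()", the export raised the following error (I am new to PyTorch, so I am also collecting other cases as I go):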

RuntimeError: CUDA out of memory. Tried to allocate 34.00 MiB (GPU 0; 7.93 GiB total capacity; 6.24 GiB already allocated; 35.19 MiB free; 113.92 MiB cached)

The error appeared when I set the batch size to 64; in other words, the GPU did not have enough memory.
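
As a rough sanity check on the numbers in that error, the memory taken by a dense float32 tensor can be estimated from its shape alone. The sketch below is plain Python, and the helper names `tensor_bytes` and `tensor_mib` are made up for illustration. The input batch by itself accounts for only a few dozen MiB, so most of the 6.24 GiB already allocated presumably comes from the weights and the intermediate activations, which would be consistent with torch.no_grad() making the difference.

```python
from functools import reduce
from operator import mul

BYTES_PER_FLOAT32 = 4


def tensor_bytes(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Return the bytes needed to store a dense tensor of the given shape."""
    return reduce(mul, shape, 1) * bytes_per_element


def tensor_mib(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Same as tensor_bytes, expressed in MiB (1 MiB = 1024 * 1024 bytes)."""
    return tensor_bytes(shape, bytes_per_element) / (1024 ** 2)


# The dummy input from the export script: 64 RGB images of 224 x 224.
print(tensor_mib((64, 3, 224, 224)))  # 36.75
# The same input with a batch size of 1, for comparison.
print(tensor_mib((1, 3, 224, 224)))
```

If the export still runs out of memory, exporting with a smaller batch in the dummy input is the obvious next thing to try.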

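Problem 2: nn.DataParallel

The torch model file my colleague gave me was trained on multiple GPUs, so my first version of the code also wrapped the network in nn.DataParallel when loading it, as shown below: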

if torch.cuda.is_available():
    net = nn.DataParallel(net)
    net= net.cuda()
net.load_state_dict(torch.load(model_path))

However, I found that loading the model in the following way leads to an error when the ONNX file is finally exported:

RuntimeError: Error(s) in loading state_dict for EfficientNet:
	Missing key(s) in state_dict: "features.init_block.conv.conv.weight", "features.init_block.conv.bn.bias", "features.init_block.conv.bn.running_var", "features.init_block.conv.bn.weight", "features.init_block.conv.bn.running_mean", "features.stage1.unit1.dw_conv.conv.weight", "features.stage1.unit1.dw_conv.bn.bias", "features.stage1.unit1.dw_conv.bn.running_var", "features.stage1.unit1.dw_conv.bn.weight", "features.stage1.unit1.dw_conv.bn.running_mean", "features.stage1.unit1.se.conv1.bias", "features.stage1.unit1.se.conv1.weight", "features.stage1.unit1.se.conv2.bias", "features.stage1.unit1.se.conv2.weight", "features.stage1.unit1.pw_conv.conv.weight", "features.stage1.unit1.pw_conv.bn.bias", "features.stage1.unit1.pw_conv.bn.running_var", "features.stage1.unit1.pw_conv.bn.weight", "features.stage1.unit1.pw_conv.bn.running_mean", "features.stage2.unit1.conv1.conv.weight", "features.stage2.unit1.conv1.bn.bias", "features.stage2.unit1.conv1.bn.running_var", "features.stage2.unit1.conv1.bn.weight", "features.stage2.unit1.conv1.bn.running_mean", "features.stage2.unit1.conv2.conv.weight", "features.stage2.unit1.conv2.bn.bias", "features.stage2.unit1.conv2.bn.running_var", "features.stage2.unit1.conv2.bn.weight", "features.stage2.unit1.conv2.bn.running_mean", "features.stage2.unit1.se.conv1.bias", "features.stage2.unit1.se.conv1.weight", "features.stage2.unit1.se.conv2.bias", "features.stage2.unit1.se.conv2.weight", "features.stage2.unit1.conv3.conv.weight", "features.stage2.unit1.conv3.bn.bias", "features.stage2.unit1.conv3.bn.running_var", "features.stage2.unit1.conv3.bn.weight", "features.stage2.unit1.conv3.bn.running_mean", "features.stage2.unit2.conv1.conv.weight", "features.stage2.unit2.conv1.bn.bias", "features.stage2.unit2.conv1.bn.running_var", "features.stage2.unit2.conv1.bn.weight", "features.stage2.unit2.conv1.bn.running_mean", "features.stage2.unit2.conv2.conv.weight", "features.stage2.unit2.conv2.bn.bias", "features.stage2.unit2.conv2.bn.running_var", 
"features.stage2.unit2.conv2.bn.weight", "features.stage2.unit2.conv2.bn.running_mean", "features.stage2.unit2.se.conv1.bias", "features.stage2.unit2.se.conv1.weight", "features.stage2.unit2.se.conv2.bias", "features.stage2.unit2.se.conv2.weight", "features.stage2.unit2.conv3.conv.weight", "features.stage2.unit2.conv3.bn.bias", "features.stage2.unit2.conv3.bn.running_var", "features.stage2.unit2.conv3.bn.weight", "features.stage2.unit2.conv3.bn.running_mean", "features.stage3.unit1.conv1.conv.weight", "features.stage3.unit1.conv1.bn.bias", "features.stage3.unit1.conv1.bn.running_var", "features.stage3.unit1.conv1.bn.weight", "features.stage3.unit1.conv1.bn.running_mean", "features.stage3.unit1.conv2.conv.weight", "features.stage3.unit1.conv2.bn.bias", "features.stage3.unit1.conv2.bn.running_var", "features.stage3.unit1.conv2.bn.weight", "features.stage3.unit1.conv2.bn.running_mean", "features.stage3.unit1.se.conv1.bias", "features.stage3.unit1.se.conv1.weight", "features.stage3.unit1.se.conv2.bias", "features.stage3.unit1.se.conv2.weight", "features.stage3.unit1.conv3.conv.weight", "features.stage3.unit1.conv3.bn.bias", "features.stage3.unit1.conv3.bn.running_var", "features.stage3.unit1.conv3.bn.weight", "features.stage3.unit1.conv3.bn.running_mean", "features.stage3.unit2.conv1.conv.weight", "features.stage3.unit2.conv1.bn.bias", "features.stage3.unit2.conv1.bn.running_var", "features.stage3.unit2.conv1.bn.weight", "features.stage3.unit2.conv1.bn.running_mean", "features.stage3.unit2.conv2.conv.weight", "features.stage3.unit2.conv2.bn.bias", "features.stage3.unit2.conv2.bn.running_var", "features.stage3.unit2.conv2.bn.weight", "features.stage3.unit2.conv2.bn.running_mean", "features.stage3.unit2.se.conv1.bias", "features.stage3.unit2.se.conv1.weight", "features.stage3.unit2.se.conv2.bias", "features.stage3.unit2.se.conv2.weight", "features.stage3.unit2.conv3.conv.weight", "features.stage3.unit2.conv3.bn.bias", "features.stage3.unit2.conv3.bn.running_var", 
"features.stage3.unit2.conv3.bn.weight", "features.stage3.unit2.conv3.bn.running_mean", "features.stage4.unit1.conv1.conv.weight", "features.stage4.unit1.conv1.bn.bias", "features.stage4.unit1.conv1.bn.running_var", "features.stage4.unit1.conv1.bn.weight", "features.stage4.unit1.conv1.bn.running_mean", "features.stage4.unit1.conv2.conv.weight", "features.stage4.unit1.conv2.bn.bias", "features.stage4.unit1.conv2.bn.running_var", "features.stage4.unit1.conv2.bn.weight", "features.stage4.unit1.conv2.bn.running_mean", "features.stage4.unit1.se.conv1.bias", "features.stage4.unit1.se.conv1.weight", "features.stage4.unit1.se.conv2.bias", "features.stage4.unit1.se.conv2.weight", "features.stage4.unit1.conv3.conv.weight", "features.stage4.unit1.conv3.bn.bias", "features.stage4.unit1.conv3.bn.running_var", "features.stage4.unit1.conv3.bn.weight", "features.stage4.unit1.conv3.bn.running_mean", "features.stage4.unit2.conv1.conv.weight", "features.stage4.unit2.conv1.bn.bias", "features.stage4.unit2.conv1.bn.running_var", "features.stage4.unit2.conv1.bn.weight", "features.stage4.unit2.conv1.bn.running_mean", "features.stage4.unit2.conv2.conv.weight", "features.stage4.unit2.conv2.bn.bias", "features.stage4.unit2.conv2.bn.running_var", "features.stage4.unit2.conv2.bn.weight", "features.stage4.unit2.conv2.bn.running_mean", "features.stage4.unit2.se.conv1.bias", "features.stage4.unit2.se.conv1.weight", "features.stage4.unit2.se.conv2.bias", "features.stage4.unit2.se.conv2.weight", "features.stage4.unit2.conv3.conv.weight", "features.stage4.unit2.conv3.bn.bias", "features.stage4.unit2.conv3.bn.running_var", "features.stage4.unit2.conv3.bn.weight", "features.stage4.unit2.conv3.bn.running_mean", "features.stage4.unit3.conv1.conv.weight", "features.stage4.unit3.conv1.bn.bias", "features.stage4.unit3.conv1.bn.running_var", "features.stage4.unit3.conv1.bn.weight", "features.stage4.unit3.conv1.bn.running_mean", "features.stage4.unit3.conv2.conv.weight", 
"features.stage4.unit3.conv2.bn.bias", "features.stage4.unit3.conv2.bn.running_var", "features.stage4.unit3.conv2.bn.weight", "features.stage4.unit3.conv2.bn.running_mean", "features.stage4.unit3.se.conv1.bias", "features.stage4.unit3.se.conv1.weight", "features.stage4.unit3.se.conv2.bias", "features.stage4.unit3.se.conv2.weight", "features.stage4.unit3.conv3.conv.weight", "features.stage4.unit3.conv3.bn.bias", "features.stage4.unit3.conv3.bn.running_var", "features.stage4.unit3.conv3.bn.weight", "features.stage4.unit3.conv3.bn.running_mean", "features.stage4.unit4.conv1.conv.weight", "features.stage4.unit4.conv1.bn.bias", "features.stage4.unit4.conv1.bn.running_var", "features.stage4.unit4.conv1.bn.weight", "features.stage4.unit4.conv1.bn.running_mean", "features.stage4.unit4.conv2.conv.weight", "features.stage4.unit4.conv2.bn.bias", "features.stage4.unit4.conv2.bn.running_var", "features.stage4.unit4.conv2.bn.weight", "features.stage4.unit4.conv2.bn.running_mean", "features.stage4.unit4.se.conv1.bias", "features.stage4.unit4.se.conv1.weight", "features.stage4.unit4.se.conv2.bias", "features.stage4.unit4.se.conv2.weight", "features.stage4.unit4.conv3.conv.weight", "features.stage4.unit4.conv3.bn.bias", "features.stage4.unit4.conv3.bn.running_var", "features.stage4.unit4.conv3.bn.weight", "features.stage4.unit4.conv3.bn.running_mean", "features.stage4.unit5.conv1.conv.weight", "features.stage4.unit5.conv1.bn.bias", "features.stage4.unit5.conv1.bn.running_var", "features.stage4.unit5.conv1.bn.weight", "features.stage4.unit5.conv1.bn.running_mean", "features.stage4.unit5.conv2.conv.weight", "features.stage4.unit5.conv2.bn.bias", "features.stage4.unit5.conv2.bn.running_var", "features.stage4.unit5.conv2.bn.weight", "features.stage4.unit5.conv2.bn.running_mean", "features.stage4.unit5.se.conv1.bias", "features.stage4.unit5.se.conv1.weight", "features.stage4.unit5.se.conv2.bias", "features.stage4.unit5.se.conv2.weight", "features.stage4.unit5.conv3.conv.weight", 
"features.stage4.unit5.conv3.bn.bias", "features.stage4.unit5.conv3.bn.running_var", "features.stage4.unit5.conv3.bn.weight", "features.stage4.unit5.conv3.bn.running_mean", "features.stage4.unit6.conv1.conv.weight", "features.stage4.unit6.conv1.bn.bias", "features.stage4.unit6.conv1.bn.running_var", "features.stage4.unit6.conv1.bn.weight", "features.stage4.unit6.conv1.bn.running_mean", "features.stage4.unit6.conv2.conv.weight", "features.stage4.unit6.conv2.bn.bias", "features.stage4.unit6.conv2.bn.running_var", "features.stage4.unit6.conv2.bn.weight", "features.stage4.unit6.conv2.bn.running_mean", "features.stage4.unit6.se.conv1.bias", "features.stage4.unit6.se.conv1.weight", "features.stage4.unit6.se.conv2.bias", "features.stage4.unit6.se.conv2.weight", "features.stage4.unit6.conv3.conv.weight", "features.stage4.unit6.conv3.bn.bias", "features.stage4.unit6.conv3.bn.running_var", "features.stage4.unit6.conv3.bn.weight", "features.stage4.unit6.conv3.bn.running_mean", "features.stage5.unit1.conv1.conv.weight", "features.stage5.unit1.conv1.bn.bias", "features.stage5.unit1.conv1.bn.running_var", "features.stage5.unit1.conv1.bn.weight", "features.stage5.unit1.conv1.bn.running_mean", "features.stage5.unit1.conv2.conv.weight", "features.stage5.unit1.conv2.bn.bias", "features.stage5.unit1.conv2.bn.running_var", "features.stage5.unit1.conv2.bn.weight", "features.stage5.unit1.conv2.bn.running_mean", "features.stage5.unit1.se.conv1.bias", "features.stage5.unit1.se.conv1.weight", "features.stage5.unit1.se.conv2.bias", "features.stage5.unit1.se.conv2.weight", "features.stage5.unit1.conv3.conv.weight", "features.stage5.unit1.conv3.bn.bias", "features.stage5.unit1.conv3.bn.running_var", "features.stage5.unit1.conv3.bn.weight", "features.stage5.unit1.conv3.bn.running_mean", "features.stage5.unit2.conv1.conv.weight", "features.stage5.unit2.conv1.bn.bias", "features.stage5.unit2.conv1.bn.running_var", "features.stage5.unit2.conv1.bn.weight", 
"features.stage5.unit2.conv1.bn.running_mean", "features.stage5.unit2.conv2.conv.weight", "features.stage5.unit2.conv2.bn.bias", "features.stage5.unit2.conv2.bn.running_var", "features.stage5.unit2.conv2.bn.weight", "features.stage5.unit2.conv2.bn.running_mean", "features.stage5.unit2.se.conv1.bias", "features.stage5.unit2.se.conv1.weight", "features.stage5.unit2.se.conv2.bias", "features.stage5.unit2.se.conv2.weight", "features.stage5.unit2.conv3.conv.weight", "features.stage5.unit2.conv3.bn.bias", "features.stage5.unit2.conv3.bn.running_var", "features.stage5.unit2.conv3.bn.weight", "features.stage5.unit2.conv3.bn.running_mean", "features.stage5.unit3.conv1.conv.weight", "features.stage5.unit3.conv1.bn.bias", "features.stage5.unit3.conv1.bn.running_var", "features.stage5.unit3.conv1.bn.weight", "features.stage5.unit3.conv1.bn.running_mean", "features.stage5.unit3.conv2.conv.weight", "features.stage5.unit3.conv2.bn.bias", "features.stage5.unit3.conv2.bn.running_var", "features.stage5.unit3.conv2.bn.weight", "features.stage5.unit3.conv2.bn.running_mean", "features.stage5.unit3.se.conv1.bias", "features.stage5.unit3.se.conv1.weight", "features.stage5.unit3.se.conv2.bias", "features.stage5.unit3.se.conv2.weight", "features.stage5.unit3.conv3.conv.weight", "features.stage5.unit3.conv3.bn.bias", "features.stage5.unit3.conv3.bn.running_var", "features.stage5.unit3.conv3.bn.weight", "features.stage5.unit3.conv3.bn.running_mean", "features.stage5.unit4.conv1.conv.weight", "features.stage5.unit4.conv1.bn.bias", "features.stage5.unit4.conv1.bn.running_var", "features.stage5.unit4.conv1.bn.weight", "features.stage5.unit4.conv1.bn.running_mean", "features.stage5.unit4.conv2.conv.weight", "features.stage5.unit4.conv2.bn.bias", "features.stage5.unit4.conv2.bn.running_var", "features.stage5.unit4.conv2.bn.weight", "features.stage5.unit4.conv2.bn.running_mean", "features.stage5.unit4.se.conv1.bias", "features.stage5.unit4.se.conv1.weight", "features.stage5.unit4.se.conv2.bias", 
"features.stage5.unit4.se.conv2.weight", "features.stage5.unit4.conv3.conv.weight", "features.stage5.unit4.conv3.bn.bias", "features.stage5.unit4.conv3.bn.running_var", "features.stage5.unit4.conv3.bn.weight", "features.stage5.unit4.conv3.bn.running_mean", "features.stage5.unit5.conv1.conv.weight", "features.stage5.unit5.conv1.bn.bias", "features.stage5.unit5.conv1.bn.running_var", "features.stage5.unit5.conv1.bn.weight", "features.stage5.unit5.conv1.bn.running_mean", "features.stage5.unit5.conv2.conv.weight", "features.stage5.unit5.conv2.bn.bias", "features.stage5.unit5.conv2.bn.running_var", "features.stage5.unit5.conv2.bn.weight", "features.stage5.unit5.conv2.bn.running_mean", "features.stage5.unit5.se.conv1.bias", "features.stage5.unit5.se.conv1.weight", "features.stage5.unit5.se.conv2.bias", "features.stage5.unit5.se.conv2.weight", "features.stage5.unit5.conv3.conv.weight", "features.stage5.unit5.conv3.bn.bias", "features.stage5.unit5.conv3.bn.running_var", "features.stage5.unit5.conv3.bn.weight", "features.stage5.unit5.conv3.bn.running_mean", "features.final_block.conv.weight", "features.final_block.bn.bias", "features.final_block.bn.running_var", "features.final_block.bn.weight", "features.final_block.bn.running_mean", "output.fc.bias", "output.fc.weight". 
	Unexpected key(s) in state_dict: "module.features.init_block.conv.conv.weight", "module.features.init_block.conv.bn.weight", "module.features.init_block.conv.bn.bias", "module.features.init_block.conv.bn.running_mean", "module.features.init_block.conv.bn.running_var", "module.features.init_block.conv.bn.num_batches_tracked", "module.features.stage1.unit1.dw_conv.conv.weight", "module.features.stage1.unit1.dw_conv.bn.weight", "module.features.stage1.unit1.dw_conv.bn.bias", "module.features.stage1.unit1.dw_conv.bn.running_mean", "module.features.stage1.unit1.dw_conv.bn.running_var", "module.features.stage1.unit1.dw_conv.bn.num_batches_tracked", "module.features.stage1.unit1.se.conv1.weight", "module.features.stage1.unit1.se.conv1.bias", "module.features.stage1.unit1.se.conv2.weight", "module.features.stage1.unit1.se.conv2.bias", "module.features.stage1.unit1.pw_conv.conv.weight", "module.features.stage1.unit1.pw_conv.bn.weight", "module.features.stage1.unit1.pw_conv.bn.bias", "module.features.stage1.unit1.pw_conv.bn.running_mean", "module.features.stage1.unit1.pw_conv.bn.running_var", "module.features.stage1.unit1.pw_conv.bn.num_batches_tracked", "module.features.stage2.unit1.conv1.conv.weight", "module.features.stage2.unit1.conv1.bn.weight", "module.features.stage2.unit1.conv1.bn.bias", "module.features.stage2.unit1.conv1.bn.running_mean", "module.features.stage2.unit1.conv1.bn.running_var", "module.features.stage2.unit1.conv1.bn.num_batches_tracked", "module.features.stage2.unit1.conv2.conv.weight", "module.features.stage2.unit1.conv2.bn.weight", "module.features.stage2.unit1.conv2.bn.bias", "module.features.stage2.unit1.conv2.bn.running_mean", "module.features.stage2.unit1.conv2.bn.running_var", "module.features.stage2.unit1.conv2.bn.num_batches_tracked", "module.features.stage2.unit1.se.conv1.weight", "module.features.stage2.unit1.se.conv1.bias", "module.features.stage2.unit1.se.conv2.weight", "module.features.stage2.unit1.se.conv2.bias", 
"module.features.stage2.unit1.conv3.conv.weight", "module.features.stage2.unit1.conv3.bn.weight", "module.features.stage2.unit1.conv3.bn.bias", "module.features.stage2.unit1.conv3.bn.running_mean", "module.features.stage2.unit1.conv3.bn.running_var", "module.features.stage2.unit1.conv3.bn.num_batches_tracked", "module.features.stage2.unit2.conv1.conv.weight", "module.features.stage2.unit2.conv1.bn.weight", "module.features.stage2.unit2.conv1.bn.bias", "module.features.stage2.unit2.conv1.bn.running_mean", "module.features.stage2.unit2.conv1.bn.running_var", "module.features.stage2.unit2.conv1.bn.num_batches_tracked", "module.features.stage2.unit2.conv2.conv.weight", "module.features.stage2.unit2.conv2.bn.weight", "module.features.stage2.unit2.conv2.bn.bias", "module.features.stage2.unit2.conv2.bn.running_mean", "module.features.stage2.unit2.conv2.bn.running_var", "module.features.stage2.unit2.conv2.bn.num_batches_tracked", "module.features.stage2.unit2.se.conv1.weight", "module.features.stage2.unit2.se.conv1.bias", "module.features.stage2.unit2.se.conv2.weight", "module.features.stage2.unit2.se.conv2.bias", "module.features.stage2.unit2.conv3.conv.weight", "module.features.stage2.unit2.conv3.bn.weight", "module.features.stage2.unit2.conv3.bn.bias", "module.features.stage2.unit2.conv3.bn.running_mean", "module.features.stage2.unit2.conv3.bn.running_var", "module.features.stage2.unit2.conv3.bn.num_batches_tracked", "module.features.stage3.unit1.conv1.conv.weight", "module.features.stage3.unit1.conv1.bn.weight", "module.features.stage3.unit1.conv1.bn.bias", "module.features.stage3.unit1.conv1.bn.running_mean", "module.features.stage3.unit1.conv1.bn.running_var", "module.features.stage3.unit1.conv1.bn.num_batches_tracked", "module.features.stage3.unit1.conv2.conv.weight", "module.features.stage3.unit1.conv2.bn.weight", "module.features.stage3.unit1.conv2.bn.bias", "module.features.stage3.unit1.conv2.bn.running_mean", "module.features.stage3.unit1.conv2.bn.running_var", 
"module.features.stage3.unit1.conv2.bn.num_batches_tracked", "module.features.stage3.unit1.se.conv1.weight", "module.features.stage3.unit1.se.conv1.bias", "module.features.stage3.unit1.se.conv2.weight", "module.features.stage3.unit1.se.conv2.bias", "module.features.stage3.unit1.conv3.conv.weight", "module.features.stage3.unit1.conv3.bn.weight", "module.features.stage3.unit1.conv3.bn.bias", "module.features.stage3.unit1.conv3.bn.running_mean", "module.features.stage3.unit1.conv3.bn.running_var", "module.features.stage3.unit1.conv3.bn.num_batches_tracked", "module.features.stage3.unit2.conv1.conv.weight", "module.features.stage3.unit2.conv1.bn.weight", "module.features.stage3.unit2.conv1.bn.bias", "module.features.stage3.unit2.conv1.bn.running_mean", "module.features.stage3.unit2.conv1.bn.running_var", "module.features.stage3.unit2.conv1.bn.num_batches_tracked", "module.features.stage3.unit2.conv2.conv.weight", "module.features.stage3.unit2.conv2.bn.weight", "module.features.stage3.unit2.conv2.bn.bias", "module.features.stage3.unit2.conv2.bn.running_mean", "module.features.stage3.unit2.conv2.bn.running_var", "module.features.stage3.unit2.conv2.bn.num_batches_tracked", "module.features.stage3.unit2.se.conv1.weight", "module.features.stage3.unit2.se.conv1.bias", "module.features.stage3.unit2.se.conv2.weight", "module.features.stage3.unit2.se.conv2.bias", "module.features.stage3.unit2.conv3.conv.weight", "module.features.stage3.unit2.conv3.bn.weight", "module.features.stage3.unit2.conv3.bn.bias", "module.features.stage3.unit2.conv3.bn.running_mean", "module.features.stage3.unit2.conv3.bn.running_var", "module.features.stage3.unit2.conv3.bn.num_batches_tracked", "module.features.stage4.unit1.conv1.conv.weight", "module.features.stage4.unit1.conv1.bn.weight", "module.features.stage4.unit1.conv1.bn.bias", "module.features.stage4.unit1.conv1.bn.running_mean", "module.features.stage4.unit1.conv1.bn.running_var", "module.features.stage4.unit1.conv1.bn.num_batches_tracked", 
"module.features.stage4.unit1.conv2.conv.weight", "module.features.stage4.unit1.conv2.bn.weight", "module.features.stage4.unit1.conv2.bn.bias", "module.features.stage4.unit1.conv2.bn.running_mean", "module.features.stage4.unit1.conv2.bn.running_var", "module.features.stage4.unit1.conv2.bn.num_batches_tracked", "module.features.stage4.unit1.se.conv1.weight", "module.features.stage4.unit1.se.conv1.bias", "module.features.stage4.unit1.se.conv2.weight", "module.features.stage4.unit1.se.conv2.bias", "module.features.stage4.unit1.conv3.conv.weight", "module.features.stage4.unit1.conv3.bn.weight", "module.features.stage4.unit1.conv3.bn.bias", "module.features.stage4.unit1.conv3.bn.running_mean", "module.features.stage4.unit1.conv3.bn.running_var", "module.features.stage4.unit1.conv3.bn.num_batches_tracked", "module.features.stage4.unit2.conv1.conv.weight", "module.features.stage4.unit2.conv1.bn.weight", "module.features.stage4.unit2.conv1.bn.bias", "module.features.stage4.unit2.conv1.bn.running_mean", "module.features.stage4.unit2.conv1.bn.running_var", "module.features.stage4.unit2.conv1.bn.num_batches_tracked", "module.features.stage4.unit2.conv2.conv.weight", "module.features.stage4.unit2.conv2.bn.weight", "module.features.stage4.unit2.conv2.bn.bias", "module.features.stage4.unit2.conv2.bn.running_mean", "module.features.stage4.unit2.conv2.bn.running_var", "module.features.stage4.unit2.conv2.bn.num_batches_tracked", "module.features.stage4.unit2.se.conv1.weight", "module.features.stage4.unit2.se.conv1.bias", "module.features.stage4.unit2.se.conv2.weight", "module.features.stage4.unit2.se.conv2.bias", "module.features.stage4.unit2.conv3.conv.weight", "module.features.stage4.unit2.conv3.bn.weight", "module.features.stage4.unit2.conv3.bn.bias", "module.features.stage4.unit2.conv3.bn.running_mean", "module.features.stage4.unit2.conv3.bn.running_var", "module.features.stage4.unit2.conv3.bn.num_batches_tracked", "module.features.stage4.unit3.conv1.conv.weight", 
"module.features.stage4.unit3.conv1.bn.weight", "module.features.stage4.unit3.conv1.bn.bias", "module.features.stage4.unit3.conv1.bn.running_mean", "module.features.stage4.unit3.conv1.bn.running_var", "module.features.stage4.unit3.conv1.bn.num_batches_tracked", "module.features.stage4.unit3.conv2.conv.weight", "module.features.stage4.unit3.conv2.bn.weight", "module.features.stage4.unit3.conv2.bn.bias", "module.features.stage4.unit3.conv2.bn.running_mean", "module.features.stage4.unit3.conv2.bn.running_var", "module.features.stage4.unit3.conv2.bn.num_batches_tracked", "module.features.stage4.unit3.se.conv1.weight", "module.features.stage4.unit3.se.conv1.bias", "module.features.stage4.unit3.se.conv2.weight", "module.features.stage4.unit3.se.conv2.bias", "module.features.stage4.unit3.conv3.conv.weight", "module.features.stage4.unit3.conv3.bn.weight", "module.features.stage4.unit3.conv3.bn.bias", "module.features.stage4.unit3.conv3.bn.running_mean", "module.features.stage4.unit3.conv3.bn.running_var", "module.features.stage4.unit3.conv3.bn.num_batches_tracked", "module.features.stage4.unit4.conv1.conv.weight", "module.features.stage4.unit4.conv1.bn.weight", "module.features.stage4.unit4.conv1.bn.bias", "module.features.stage4.unit4.conv1.bn.running_mean", "module.features.stage4.unit4.conv1.bn.running_var", "module.features.stage4.unit4.conv1.bn.num_batches_tracked", "module.features.stage4.unit4.conv2.conv.weight", "module.features.stage4.unit4.conv2.bn.weight", "module.features.stage4.unit4.conv2.bn.bias", "module.features.stage4.unit4.conv2.bn.running_mean", "module.features.stage4.unit4.conv2.bn.running_var", "module.features.stage4.unit4.conv2.bn.num_batches_tracked", "module.features.stage4.unit4.se.conv1.weight", "module.features.stage4.unit4.se.conv1.bias", "module.features.stage4.unit4.se.conv2.weight", "module.features.stage4.unit4.se.conv2.bias", "module.features.stage4.unit4.conv3.conv.weight", "module.features.stage4.unit4.conv3.bn.weight", 
"module.features.stage4.unit4.conv3.bn.bias", "module.features.stage4.unit4.conv3.bn.running_mean", "module.features.stage4.unit4.conv3.bn.running_var", "module.features.stage4.unit4.conv3.bn.num_batches_tracked", "module.features.stage4.unit5.conv1.conv.weight", "module.features.stage4.unit5.conv1.bn.weight", "module.features.stage4.unit5.conv1.bn.bias", "module.features.stage4.unit5.conv1.bn.running_mean", "module.features.stage4.unit5.conv1.bn.running_var", "module.features.stage4.unit5.conv1.bn.num_batches_tracked", "module.features.stage4.unit5.conv2.conv.weight", "module.features.stage4.unit5.conv2.bn.weight", "module.features.stage4.unit5.conv2.bn.bias", "module.features.stage4.unit5.conv2.bn.running_mean", "module.features.stage4.unit5.conv2.bn.running_var", "module.features.stage4.unit5.conv2.bn.num_batches_tracked", "module.features.stage4.unit5.se.conv1.weight", "module.features.stage4.unit5.se.conv1.bias", "module.features.stage4.unit5.se.conv2.weight", "module.features.stage4.unit5.se.conv2.bias", "module.features.stage4.unit5.conv3.conv.weight", "module.features.stage4.unit5.conv3.bn.weight", "module.features.stage4.unit5.conv3.bn.bias", "module.features.stage4.unit5.conv3.bn.running_mean", "module.features.stage4.unit5.conv3.bn.running_var", "module.features.stage4.unit5.conv3.bn.num_batches_tracked", "module.features.stage4.unit6.conv1.conv.weight", "module.features.stage4.unit6.conv1.bn.weight", "module.features.stage4.unit6.conv1.bn.bias", "module.features.stage4.unit6.conv1.bn.running_mean", "module.features.stage4.unit6.conv1.bn.running_var", "module.features.stage4.unit6.conv1.bn.num_batches_tracked", "module.features.stage4.unit6.conv2.conv.weight", "module.features.stage4.unit6.conv2.bn.weight", "module.features.stage4.unit6.conv2.bn.bias", "module.features.stage4.unit6.conv2.bn.running_mean", "module.features.stage4.unit6.conv2.bn.running_var", "module.features.stage4.unit6.conv2.bn.num_batches_tracked", 
"module.features.stage4.unit6.se.conv1.weight", "module.features.stage4.unit6.se.conv1.bias", "module.features.stage4.unit6.se.conv2.weight", "module.features.stage4.unit6.se.conv2.bias", "module.features.stage4.unit6.conv3.conv.weight", "module.features.stage4.unit6.conv3.bn.weight", "module.features.stage4.unit6.conv3.bn.bias", "module.features.stage4.unit6.conv3.bn.running_mean", "module.features.stage4.unit6.conv3.bn.running_var", "module.features.stage4.unit6.conv3.bn.num_batches_tracked", "module.features.stage5.unit1.conv1.conv.weight", "module.features.stage5.unit1.conv1.bn.weight", "module.features.stage5.unit1.conv1.bn.bias", "module.features.stage5.unit1.conv1.bn.running_mean", "module.features.stage5.unit1.conv1.bn.running_var", "module.features.stage5.unit1.conv1.bn.num_batches_tracked", "module.features.stage5.unit1.conv2.conv.weight", "module.features.stage5.unit1.conv2.bn.weight", "module.features.stage5.unit1.conv2.bn.bias", "module.features.stage5.unit1.conv2.bn.running_mean", "module.features.stage5.unit1.conv2.bn.running_var", "module.features.stage5.unit1.conv2.bn.num_batches_tracked", "module.features.stage5.unit1.se.conv1.weight", "module.features.stage5.unit1.se.conv1.bias", "module.features.stage5.unit1.se.conv2.weight", "module.features.stage5.unit1.se.conv2.bias", "module.features.stage5.unit1.conv3.conv.weight", "module.features.stage5.unit1.conv3.bn.weight", "module.features.stage5.unit1.conv3.bn.bias", "module.features.stage5.unit1.conv3.bn.running_mean", "module.features.stage5.unit1.conv3.bn.running_var", "module.features.stage5.unit1.conv3.bn.num_batches_tracked", "module.features.stage5.unit2.conv1.conv.weight", "module.features.stage5.unit2.conv1.bn.weight", "module.features.stage5.unit2.conv1.bn.bias", "module.features.stage5.unit2.conv1.bn.running_mean", "module.features.stage5.unit2.conv1.bn.running_var", "module.features.stage5.unit2.conv1.bn.num_batches_tracked", "module.features.stage5.unit2.conv2.conv.weight", 
"module.features.stage5.unit2.conv2.bn.weight", "module.features.stage5.unit2.conv2.bn.bias", "module.features.stage5.unit2.conv2.bn.running_mean", "module.features.stage5.unit2.conv2.bn.running_var", "module.features.stage5.unit2.conv2.bn.num_batches_tracked", "module.features.stage5.unit2.se.conv1.weight", "module.features.stage5.unit2.se.conv1.bias", "module.features.stage5.unit2.se.conv2.weight", "module.features.stage5.unit2.se.conv2.bias", "module.features.stage5.unit2.conv3.conv.weight", "module.features.stage5.unit2.conv3.bn.weight", "module.features.stage5.unit2.conv3.bn.bias", "module.features.stage5.unit2.conv3.bn.running_mean", "module.features.stage5.unit2.conv3.bn.running_var", "module.features.stage5.unit2.conv3.bn.num_batches_tracked", "module.features.stage5.unit3.conv1.conv.weight", "module.features.stage5.unit3.conv1.bn.weight", "module.features.stage5.unit3.conv1.bn.bias", "module.features.stage5.unit3.conv1.bn.running_mean", "module.features.stage5.unit3.conv1.bn.running_var", "module.features.stage5.unit3.conv1.bn.num_batches_tracked", "module.features.stage5.unit3.conv2.conv.weight", "module.features.stage5.unit3.conv2.bn.weight", "module.features.stage5.unit3.conv2.bn.bias", "module.features.stage5.unit3.conv2.bn.running_mean", "module.features.stage5.unit3.conv2.bn.running_var", "module.features.stage5.unit3.conv2.bn.num_batches_tracked", "module.features.stage5.unit3.se.conv1.weight", "module.features.stage5.unit3.se.conv1.bias", "module.features.stage5.unit3.se.conv2.weight", "module.features.stage5.unit3.se.conv2.bias", "module.features.stage5.unit3.conv3.conv.weight", "module.features.stage5.unit3.conv3.bn.weight", "module.features.stage5.unit3.conv3.bn.bias", "module.features.stage5.unit3.conv3.bn.running_mean", "module.features.stage5.unit3.conv3.bn.running_var", "module.features.stage5.unit3.conv3.bn.num_batches_tracked", "module.features.stage5.unit4.conv1.conv.weight", "module.features.stage5.unit4.conv1.bn.weight", 
"module.features.stage5.unit4.conv1.bn.bias", "module.features.stage5.unit4.conv1.bn.running_mean", "module.features.stage5.unit4.conv1.bn.running_var", "module.features.stage5.unit4.conv1.bn.num_batches_tracked", "module.features.stage5.unit4.conv2.conv.weight", "module.features.stage5.unit4.conv2.bn.weight", "module.features.stage5.unit4.conv2.bn.bias", "module.features.stage5.unit4.conv2.bn.running_mean", "module.features.stage5.unit4.conv2.bn.running_var", "module.features.stage5.unit4.conv2.bn.num_batches_tracked", "module.features.stage5.unit4.se.conv1.weight", "module.features.stage5.unit4.se.conv1.bias", "module.features.stage5.unit4.se.conv2.weight", "module.features.stage5.unit4.se.conv2.bias", "module.features.stage5.unit4.conv3.conv.weight", "module.features.stage5.unit4.conv3.bn.weight", "module.features.stage5.unit4.conv3.bn.bias", "module.features.stage5.unit4.conv3.bn.running_mean", "module.features.stage5.unit4.conv3.bn.running_var", "module.features.stage5.unit4.conv3.bn.num_batches_tracked", "module.features.stage5.unit5.conv1.conv.weight", "module.features.stage5.unit5.conv1.bn.weight", "module.features.stage5.unit5.conv1.bn.bias", "module.features.stage5.unit5.conv1.bn.running_mean", "module.features.stage5.unit5.conv1.bn.running_var", "module.features.stage5.unit5.conv1.bn.num_batches_tracked", "module.features.stage5.unit5.conv2.conv.weight", "module.features.stage5.unit5.conv2.bn.weight", "module.features.stage5.unit5.conv2.bn.bias", "module.features.stage5.unit5.conv2.bn.running_mean", "module.features.stage5.unit5.conv2.bn.running_var", "module.features.stage5.unit5.conv2.bn.num_batches_tracked", "module.features.stage5.unit5.se.conv1.weight", "module.features.stage5.unit5.se.conv1.bias", "module.features.stage5.unit5.se.conv2.weight", "module.features.stage5.unit5.se.conv2.bias", "module.features.stage5.unit5.conv3.conv.weight", "module.features.stage5.unit5.conv3.bn.weight", "module.features.stage5.unit5.conv3.bn.bias", 
"module.features.stage5.unit5.conv3.bn.running_mean", "module.features.stage5.unit5.conv3.bn.running_var", "module.features.stage5.unit5.conv3.bn.num_batches_tracked", "module.features.final_block.conv.weight", "module.features.final_block.bn.weight", "module.features.final_block.bn.bias", "module.features.final_block.bn.running_mean", "module.features.final_block.bn.running_var", "module.features.final_block.bn.num_batches_tracked", "module.output.fc.weight", "module.output.fc.bias". 

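Looking closely, every key in a model saved through DataParallel carries an extra 'module.' prefix, while the plain model expects keys without it, so the saved parameters cannot be matched to the model when the weights are loaded before ONNX export. To fix this I followed the approach in this blog post: "How to fix model parameter problems after DataParallel in PyTorch" (解决pytorch中DataParallel后模型参数出现问题的方法).
In short: build an ordered dict yourself, copy each entry over under its key with the prefix stripped, and load that into the model so that torch.onnx can work with it. The correct code is the listing shown above Problem 1.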
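The key-renaming step described above can be sketched on its own, without touching the model. This is a minimal sketch, not the code from the referenced post: the helper name `strip_module_prefix` is hypothetical, and the string values stand in for the tensors that `torch.load` would return. Unlike the fixed `k[7:]` slice, it strips the prefix only when it is actually present.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip only when the prefix is really there, so a checkpoint saved
        # without DataParallel passes through unchanged.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Stand-in values; in the real script these would be tensors from torch.load.
saved = OrderedDict([
    ("module.features.final_block.conv.weight", "w"),
    ("module.output.fc.bias", "b"),
])
cleaned = strip_module_prefix(saved)
print(list(cleaned.keys()))
```

In the export script this would presumably be used as `net.load_state_dict(strip_module_prefix(torch.load(model_path)))`.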

Problem 3: some ONNX attributes cannot be exported (Failed to export an ONNX attribute)

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible
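According to posts online, this is a PyTorch version issue, and the fix is to downgrade from 1.1.0 to 1.0.1. (For the error I hit when installing from the Tsinghua mirror, see Problem 4.)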

Problem 4: pip error with an external index: Cannot determine archive format of /tmp/pip-fd2724ty-build

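Installing from the Tsinghua mirror failed with an error:
[screenshot]
After some searching, I found that installing from an external index (the Tsinghua mirror) needs a different command: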

sudo pip3 install --index-url https://pypi.tuna.tsinghua.edu.cn/simple/ torch==1.0.1

[screenshot]
After that it worked.

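Problem 5: after installing torch 1.0.1, the version silently becomes 1.1.0

As the screenshot in Problem 4 shows, I had clearly installed torch 1.0.1, yet the version actually in use was 1.1.0: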

>>> import torch
>>> torch.__version__
'1.1.0'
>>> 

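So I suspected the torchvision install that came after torch 1.0.1, since I had not pinned its version. I checked, and sure enough, installing torchvision upgrades torch.
I then reinstalled torch 1.0.1 and installed torchvision pinned to 0.2.2.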
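To catch this kind of silent upgrade before it costs debugging time, a small check could run at the top of the export script. This is a sketch under assumptions: the helper names `release_tuple` and `check_version` are hypothetical, and only the leading numeric part of the version is compared, so that a build suffix such as `.post2` or `+cu90` does not cause a false alarm. It avoids f-strings so that it also runs on Python 3.5.

```python
def release_tuple(version):
    """Parse the leading numeric part of a version string into a tuple of ints."""
    parts = []
    # Drop any local build suffix ('+...'), then read numeric components
    # until the first non-numeric one (e.g. 'post2').
    for piece in version.split("+", 1)[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_version(name, installed, expected):
    """Raise RuntimeError if the installed release differs from the pinned one."""
    if release_tuple(installed) != release_tuple(expected):
        raise RuntimeError(
            "%s is %s but %s was expected; reinstall with %s==%s"
            % (name, installed, expected, name, expected)
        )
    return True


# In the export script the second argument would be torch.__version__.
print(check_version("torch", "1.0.1", "1.0.1"))
```

Called as `check_version("torch", torch.__version__, "1.0.1")`, it would have failed immediately on the 1.1.0 install shown above.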

Problem 6: RuntimeError: ONNX symbolic expected a constant value in the trace

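I could not find an answer on Baidu, and for now I cannot get past the firewall to search elsewhere.
In desperation I stepped through the source in a debugger and found that, at the point of failure, value was: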

value={Value} 391 defined in (%391 : int[] = prim::ListConstruct(%377, %382, %385, %390), scope: EfficientNet/Sequential[features]/EffiInitBlock[init_block])

So value %391 is produced by prim::ListConstruct, and evidently torch raises the error because it is not an onnx::Constant.
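One way such a non-constant list can arise is when padding amounts are computed from the input tensor's size inside forward, so the tracer records them as graph values instead of constants. This is an assumption, since the code of the init block is not shown here. If that is the cause and the input size is fixed (224x224 in this export), a possible workaround is to compute the numbers once in plain Python and pass fixed ints to the padding call. The sketch below assumes TensorFlow-style "SAME" padding; the helper name `static_same_padding` and the kernel size and stride used are hypothetical.

```python
import math


def static_same_padding(in_size, kernel_size, stride=1, dilation=1):
    """TensorFlow-style 'SAME' padding for one spatial dimension, as plain ints.

    Returns (pad_before, pad_after).
    """
    out_size = math.ceil(in_size / stride)
    effective_kernel = (kernel_size - 1) * dilation + 1
    total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    pad_before = total // 2
    # Any odd leftover pixel goes to the end, as in TensorFlow.
    return pad_before, total - pad_before


# Assumed values for illustration: 224x224 input, 3x3 kernel, stride 2.
height, width = 224, 224
pad_top, pad_bottom = static_same_padding(height, kernel_size=3, stride=2)
pad_left, pad_right = static_same_padding(width, kernel_size=3, stride=2)

# (left, right, top, bottom) is the order torch.nn.functional.pad expects
# for the last two dimensions of a 4-D input.
static_pads = (pad_left, pad_right, pad_top, pad_bottom)
print(static_pads)
```

Because these values are ordinary Python ints computed outside the traced forward pass, the exporter would see them as constants. The trade-off is that the exported model only accepts the input size used to compute them.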
