LPRnet中maxpool3d无法成功转ONNX改写

解决ONNX转换问题:YoloV5+LRPnet中maxpool3d的替换策略
文章介绍了在使用YoloV5结合LRPnet进行道路车牌检测和识别的模型训练后,遇到ONNX不支持maxpool3d算子导致模型导出失败的问题。作者提出了通过替换nn.MaxPool3d为自定义的MaxPool3d_muti_batch模块来解决这个问题,该模块利用MaxPool2d和MaxPool1d组合实现类似的功能。经过验证,替换后的代码能够得到相同的输出,从而成功将模型转换为ONNX格式。

使用yolov5+LRPnet进行道路车牌检测和识别,车牌识别模型训练好后使用torch.onnx.export导出的时候遇到了报错,主要问题就是ONNX不支持maxpool3d算子,无法直接进行转换,本文主要针对这一块进行处理

1、LPRnet主干代码

从LPRnet的主代码中可以得知,其中主要有三个地方用到了nn.MaxPool3d,这是本次模型转onnx无法成功的关键,而nn.MaxPool3d只是求取区域内的最大值,不会造成权重的更新,因此只需要更新代码实现方式,获得相对应的输出即可

class LPRNet(nn.Module):
    def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
        super(LPRNet, self).__init__()
        self.phase = phase
        self.lpr_max_len = lpr_max_len
        self.class_num = class_num
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),  # 0  [bs,3,24,94] -> [bs,64,22,92]
            nn.BatchNorm2d(num_features=64),  # 1  -> [bs,64,22,92]
            nn.ReLU(),  # 2  -> [bs,64,22,92]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),  # 3  -> [bs,64,20,90]
            small_basic_block(ch_in=64, ch_out=128),  # 4  -> [bs,128,20,90]
            nn.BatchNorm2d(num_features=128),  # 5  -> [bs,128,20,90]
            nn.ReLU(),  # 6  -> [bs,128,20,90]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),  # 7  -> [bs,64,18,44]
            small_basic_block(ch_in=64, ch_out=256),  # 8  -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 9  -> [bs,256,18,44]
            nn.ReLU(),  # 10 -> [bs,256,18,44]
            small_basic_block(ch_in=256, ch_out=256),  # 11 -> [bs,256,18,44]
            nn.BatchNorm2d(num_features=256),  # 12 -> [bs,256,18,44]
            nn.ReLU(),  # 13 -> [bs,256,18,44]
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),  # 14 -> [bs,64,16,21]
            nn.Dropout(dropout_rate),  # 0.5 dropout rate                          # 15 -> [bs,64,16,21]
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),  # 16 -> [bs,256,16,18]
            nn.BatchNorm2d(num_features=256),  # 17 -> [bs,256,16,18]
            nn.ReLU(),  # 18 -> [bs,256,16,18]
            nn.Dropout(dropout_rate),  # 0.5 dropout rate    
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI小花猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值