转yolov5模型时,不转后处理部分

本文讲述了如何在将后处理部分从硬件转移到CPU以支持不被硬件直接处理的情况下的模型转换,涉及TFLite和ONNX版本的TensorFlow模型中Detect类的相应代码修改,以及坐标变换的处理策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、背景

由于部署在硬件上的时候,后处理部分硬件处理不支持,需要挪到cpu上处理。

二、转int8.tflite版本时挪出后处理部分

  1. 需要修改的文件models/tf.py
    TFDetect(keras.layers.Layer)call()函数修改为下面部分
    def call(self, inputs):
        print('************* deploy ******************')
        z = []  # inference output
        x = []
        for i in range(self.nl):
        	###  原始 ####
            # x.append(self.m[i](inputs[i])) 
            ######## 新增 ########
            if True: 
                temp = self.m[i](inputs[i])
                z.append(tf.reshape(temp,[-1,6]))
                # print('shape', self.m[i](inputs[i]).reshape(-1,6).shape)
            continue
            ######## 新增 ########
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = x[i]
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
                xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]  # xy
                wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
        ######## 新增 ########
        return (tf.concat(z, 0), )
        ######## 新增 ########
        ######## 原始 ########
        # return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), ) ## org
  1. 模型转换代码
python export.py --weights ckpt/best_620.pt --imgsz 320   --opset 10 --include tflite --int8
  1. 变化如下图,坐标变换部分删掉了(这里使用的是320320的,检测头删掉了4040分辨率部分)
    在这里插入图片描述

二、转onnx版本时挪出后处理部分

  1. 需要修改的文件models/yolo.py
    Detect(nn.Module)forward()函数修改为下面部分
    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
           
            ######## 新增 ########
            if True:
                z.append(x[i].reshape(18, -1))
            continue
            ######## 新增 ########
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))
        print('self.export',self.export)

        return  x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

将坐标变换部分跳过,不执行
2. 模型转换代码

python export.py --weights ckpt/best_620.pt --imgsz 320   --opset 10 --include onnx
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值