从0到1制作单只鳌虾运动轨迹追踪软件

前言

需要准备windows10操作系统,python3.11.9cuDNN8.9.2.26CUDA11.8paddleDetection2.7

流程:

  1. 准备数据集-澳洲鳌虾VOC数据集 
  2. 基于RT-DETR目标检测模型训练
  3. 导出onnx模型进行python部署
  4. 平滑滤波处理视频帧保留的物体质心坐标
  5. 基于pywebview为软件前端,falsk为软件后端制作UI
  6. 使用pyinstaller打包成exe
  7. 使用into setup生成安装包

本人代码禁止任何商业化用途,个人开发者随意。所有代码均开源

项目目录

XXX 项目总目录
    static 存放js静态文件
        plotly.js
    templates 存放html文件
        index.html
    temp 用户上传文件保存路径
    venv 虚拟环境
    main.py 主程序
    model.onnx 模型文件
    1.ico 打包的程序图标

准备数据集

点击下载澳洲鳌虾VOC数据集

下载后解压,文件目录为

data
    Annotations
        0.xml
        1.xml
        ...
    imgs
        0.jpg
        1.jpg
        ...
    lables.txt

然后使用如下的脚本把数据集划分为训练集和测试集

import os
import random
import shutil


def splitDatasets(images_dir,xmls_dir,train_dir,test_dir):

    if os.path.exists(train_dir):
        shutil.rmtree(train_dir)
        
    os.makedirs(train_dir)
    os.makedirs(train_dir+'/imgs')
    os.makedirs(train_dir+'/annotations')
        
    if os.path.exists(test_dir):
        shutil.rmtree(test_dir)
        
    os.makedirs(test_dir)
    os.makedirs(test_dir+'/imgs')
    os.makedirs(test_dir+'/annotations')
        
    images=os.listdir(images_dir)
    random.shuffle(images)

    split_index=int(0.9*len(images))

    train_images=images[:split_index]
    test_images=images[split_index:]

    with open(train_dir+'/train.txt','w') as file:
        for img in train_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(train_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(train_dir,'annotations',ann))
            line=os.path.join(train_dir,'imgs',img)+' '+os.path.join(train_dir,'annotations',ann)+'\n'
            file.write(line)

    with open(test_dir+'/test.txt','w') as file:
        for img in test_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(test_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(test_dir,'annotations',ann))
            line=os.path.join(test_dir,'imgs',img)+' '+os.path.join(test_dir,'annotations',ann)+'\n'
            file.write(line)
        
    shutil.rmtree(images_dir)
    shutil.rmtree(xmls_dir)
    
if __name__=='__main__':
    # 填写img文件夹所在绝对路径
    images_dir='/home/aistudio/work/voc/imgs'
    # 填写Annotations文件夹所在绝对路径
    xmls_dir='/home/aistudio/work/voc/Annotations'
    # 填写 训练集 的存放的绝对路径
    train_dir='/home/aistudio/work/voc/trains'
    # 填写 测试集 的存放的绝对路径
    test_dir='/home/aistudio/work/voc/tests'
    
    splitDatasets(images_dir,xmls_dir,train_dir,test_dir)

训练模型

可在aistudio云平台训练,我放好了所有的相关文件,点击进入,里面的说明很详细

也可在本地进行训练,下面来配置本地的训练环境

配置相关文件

下载paddleDetection2.7

原始目录如下

paddleDetection2.7
    .github
    .travis
    activity
    benchmark
    configs 模型配置文件
    dataset 里面有数据集下载的脚本文件
    demo
    deploy 推理的相关文件
    docs 说明文档
    industrial_tutorial
    ppdet 模型运行的核心文件
    scripts
    test_pic
    tools 模型训练入口,测试,验证,导出等脚本文件
    .gitignore
    .pre-commit-config.yaml
    .style.yapf
    .travis.yml
    LICENSE
    README_cn.md 说明文档中文版
    README_en.md 说明文档英文版
    requirements.txt 相关依赖库
    setup.py 模型编译的相关脚本

需要删除一些目录,把README_en.md改名为README.md,处理过的目录如下

paddleDetection2.7
    configs
    dataset
    deploy
    ppdet
    tools
    README.md
    requirements.txt
    setup.py

把dataset里所有东西都删除,再将划分好的数据集放到该文件下,处理好的目录如下

dataset
    voc
        trains
            annotations
            imgs
            train.txt
        tests
            annotations
            imgs
            test.txt
        labels.txt

进入tools目录,只保留如下文件,其余全删除,处理后的文件目录如下

tools
    train.py
    infer.py
    eval.py
    export_model.py

进入configs目录,只保留下面三个文件和目录,处理后的目录如下

configs
    datasets
    rtdetr
    runtime.yml

进入datasets目录,只保留voc.yml,其余文件全删除,处理后的目录如下

datasets
    voc.yml

并用如下内容覆盖voc.yml

metric: VOC
map_type: 11point
num_classes: 1

TrainDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: trains/train.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

EvalDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: tests/test.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

TestDataset:
  name: ImageFolder
  anno_path: dataset/labels.txt

进入rtdetr目录,只保留如下2个文件和目录,处理后的目录如下:

rtdetr
    _base_
    rtdetr_hgnetv2_x_6x_coco.yml

进入_base_目录,找到optimizer_6x.yml,修改第一行为epoch: 200,意思是训练200轮

找到rtdetr_reader.yml,根据自己的CPU和GPU调整相关参数,如果是4核CPU,worker_num可为8,batch_size根据显存调整,占用80%到90%的显存即可

安装依赖库

建议在虚拟环境中操作

!pip install -r requirements.txt
!pip install pycocotools
!pip install filterpy
!pip install flask
!pip install pyinstaller
!pip install pywebview
!pip install onnxruntime-gpu
!pip install onnxruntime
!pip install onnx
!pip install paddle2onnx
!python setup.py install

开始训练

建议命令行输入,先进入paddleDetection所在位置,再执行以下命令

python tools/train.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml --eval --use_vdl True --vdl_log_dir vdl_log_dir/scalar

然后就是漫长的等待

导出模型

生成的模型在paddleDetection/output/best_model/model.pdparams

先进入paddleDetection所在位置,再执行以下命令

python tools/export_model.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml -o weights=output/best_model/model.pdparams

转onnx

先进入paddleDetection所在位置,再执行以下命令,可以根据需要选择保存路径

paddle2onnx --model_dir=output_inference/rtdetr_hgnetv2_x_6x_coco/ \
            --model_filename model.pdmodel  \
            --params_filename model.pdiparams \
            --opset_version 16 \
            --save_file /home/work/infer/model.onnx

模型部署

导包

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

总览代码

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

由于是对视频进行推理,所以首先得初始化视频打开的方法

def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

初始化onnx运行引擎,优先使用显卡,如果CUDA环境有问题,就使用CPU运行

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])

onnx引擎需要一定的输入格式,放到类的init里

    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

在提取每一帧后需要进行图像处理,resize图片为模型输入的要求,归一化

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img

在提取到视频的每一帧中的鳌虾的质心坐标后,由于每一帧的图像都不一样,输入模型后再输出的结果就不一样,会抖动,也就是噪声,我们需要滤波去噪,这里使用平滑滤波,相比卡尔曼滤波简单使用快速出结果。

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

我们需要计算鳌虾的运动总路程,用滤波后的质心坐标计算

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

滤波后的质心坐标是numpy数组,需要一定的转换再发送到前端进行渲染(matplotlib画的图太丑了,不如plotly.js)

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list

在获取每一帧图像后,送入模型。模型会输出一对numpy数组,需要进行一对的后处理,低于阈值的就抛弃,然后取阈值最高的,计算质心坐标并保存

    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

需要在一个主函数里将上述打开视频,图像预处理,送入模型,后处理连起来

    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

前端的设计

以pywebview为平台,html和css设计前端

 

 

 

代码总览

index.html

<!DOCTYPE html>
<html>
<head>
    <title></title>
    <link rel="shortcut icon" href="#" />
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script src="../static/plotly.js"></script>
    <style>
        html,body{
            width: 100%;
            height: 100%;
            margin: 0 auto;

        }
        body{
            display: flex;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: rgb(6, 32, 80);
        }
        main{
            display: grid;
            grid-template-columns: 1fr 3fr;
            column-gap: 2%;
            width: 98%;
            height: 98%;
        }
        fieldset{
            border: 2px solid rgb(32, 139, 139);
            color: rgb(32, 139, 139);
            margin: 8% 0 8% 0;
        }
        #s2{
            text-align: center;
            display: flex;
            justify-content: center;
            align-items: center;
            background-color: rgba(32, 139, 139, 0.301);
            border: 2px solid rgb(32, 139, 139);
        }
        #progress-circle{
            border: 1em solid rgb(32, 139, 139);
            width: 40vh;
            height: 40vh;
            border-radius: 20vh;
            display: flex;  
            justify-content: center; 
            align-items: center;
        }
        #progress-num{
            font-size: 18vh;
            color: rgb(32, 139, 139);
        }

    </style>
</head>
<body>
    <main>
        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>
            <script>
                async function stopRun(){
                    try{
                        const response=await fetch('/stopRun',{method:'POST'})
                        if (!response.ok) {  
                            throw new Error('Network response was not ok.');  
                        }
                        data=await response.json()
                        alert(data.data)
                    }catch(error){
                        console.log(error)
                    }
                }
                
                async function submit_to(){
                    // 防重复激发
                    const button = document.getElementById('submit');  
                    button.disabled = true;
                    try{
                        // 获取文件
                        const input=document.getElementById('video')
                        const file=input.files[0]
                        if (!file){
                            throw new Error('未选择文件')
                        }
                        if(file.type!=='video/mp4'){
                            throw new Error('请选择MP4文件')
                        }
                        // 刷新界面 
                        const s2=document.getElementById('s2')
                        Plotly.purge(s2)
                        // 初始化进度显示
                        const progressCircle=document.getElementById('progress-circle')
                        const progressNum=document.getElementById('progress-num')
                        progressCircle.style.display='flex'
                        progressNum.innerHTML='0%'
                        // 更新进度
                        let source = new EventSource("/progress")
                        source.onmessage = function(event) {
                        progressNum.innerHTML = event.data+'%'
                        }
                        // 发送请求
                        const formData=new FormData()
                        formData.append('video', file)
                        const response=await fetch('/shrimp',{method:'POST',body:formData})
                        if (!response.ok) {
                            throw new Error('Network response was not ok.');  
                        }
                        source.close()
                        const data=await response.json()
                        button.disabled=false
                        if(data.data==='任务被终止'){
                            alert(data.data)
                        }
                        else{
                            progressCircle.style.display='none'
                            $('#distance').text('总路程'+data.distance)
                            // 画图
                            var trace=[{
                                x: data.position_data.map(item=>item[0]),
                                y: data.position_data.map(item=>item[1]),
                                mode:"lines",
                                line:{
                                        color:'rgb(32, 139, 139)'
                                    }
                            }]
                            var layout = {
                                xaxis: {
                                    range: [0, 600],
                                    title: "x(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },
                                yaxis: {range: [0, 600],
                                    title: "y(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },  
                                title: "鳌虾运动轨迹",
                                titlefont:{
                                    color:'rgb(32, 139, 139)'
                                },
                                plot_bgcolor: 'rgba(0,0,0,0)',
                                paper_bgcolor:'rgba(0,0,0,0)'
                                }
                            Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true }) 
                        }
                    }catch(error){
                        button.disabled = false
                        if(error.message.startsWith('Failed to fetch')){}
                        else{alert(error)}
                    }
                }
            
            </script>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>
    </main>
</body>
</html>

plotly.js从官网下载

代码分览

总体设计是以<html>和<body>为底,<main>为主容器内使用grid2列布局,2个<section>作为内容器占据左右2个网格。

左边的<section>容纳文件上传表单,功能按钮,数据显示,使用说明

        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

之间用<fieldset>做了区域划分,简单又美观。

<button>均使用onclick属性进行触发

在上传前会检测用户是否选择文件,是否选择的是MP4文件

// 获取文件
const input=document.getElementById('video')
const file=input.files[0]
if (!file){
    throw new Error('未选择文件')
}
if(file.type!=='video/mp4'){
    throw new Error('请选择MP4文件')
}

 一共有3个请求:

  • 请求上传文件,将MP4上传给后端,然后后端运行模型发送质心坐标给前端渲染
  • 请求终止程序,当用户想终止后端运行模型,重新上传文件时
  • 请求获取模型处理进度,后端返回进度给前端,前端进行渲染展示

画轨迹图,前端用plotly.js将质心坐标进行渲染,同时轨迹图还有一定的交互能力。

// 画图
var trace=[{
    x: data.position_data.map(item=>item[0]),
    y: data.position_data.map(item=>item[1]),
    mode:"lines",
    line:{
            color:'rgb(32, 139, 139)'
        }
}]
var layout = {
    xaxis: {
        range: [0, 600],
        title: "x(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },
    yaxis: {range: [0, 600],
        title: "y(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },  
    title: "鳌虾运动轨迹",
    titlefont:{
        color:'rgb(32, 139, 139)'
    },
    plot_bgcolor: 'rgba(0,0,0,0)',
    paper_bgcolor:'rgba(0,0,0,0)'
    }
Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true })

其余的就是代码的排布顺序,异步执行调度,错误处理能力,系统稳定性,用户交互能力的提升,细节很多,均包含在代码中


右边的<section>容纳进度圈,轨迹图

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>

在文件上传时,就初始化渲染进度条,然后异步请求获取进度,渲染到页面;当进度到达一定值,比如99%,就关闭获取进度的请求,同时设置进度条的display=none。当用户打断程序执行或者重新运行程序,就清理轨迹图,初始化进度条,循环往复。

后端设计

后端整体使用flask,jinjia模板,将flask与pywebview结合。把模型检测代码封装到一个类TrackShrimp,其余的就是各种请求函数。

代码总览

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)


app = Flask(__name__)
UPLOAD_FOLDER = './temp'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0
run_task=False

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

# 检测视频页面
@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

# 进度查询路由
@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

if __name__=='__main__':
    # 启动后端  
    # flask_thread = threading.Thread(target=run_flask)  
    # flask_thread.start()
    # time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    # webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

代码分览

一个onnx部署的类TrackShrimp,详细见前面。

一些常量的定义

app = Flask(__name__)
UPLOAD_FOLDER = './temp' # 文件的上传路径,后端需要该路径保留用户上传的文件
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0 # 实时进度,初始化进度为0
run_task=False # 一个onnx模型是否在运行的标志,用于接收用户中断信号从而终止模型运行

定义一个flask的·启动函数,用于web调试,浏览器F12启动调试窗口

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

主页面的请求函数,该页面为主要的UI

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

用户请求中断的请求函数

首先通过标志位(run_task)检测模型是否在跑,如果检测到模型正在运行,就把标志位设为False,然后把进度归0

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

进度查询

这里设置当进度为95%时,就停止查询。

@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

一个检测的入口函数

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

接收用户上传文件的函数

一旦用户上传文件,就设置运行标志位为True,然后将文件保存,再送入模型运行接口函数,当用户请求终止时,results为None,所以使用if else进行区分。模型结果出来后就把标志位设为False,同时将数据传到前端

@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

接着就是启动所有代码了,为了调试方便,我写了2份代码,一份用于调试,一份用于成品

if __name__=='__main__':
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    webview.start()
if __name__=='__main__':
    # 启动后端  
    flask_thread = threading.Thread(target=run_flask)  
    flask_thread.start()
    time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

pyinstaller打包

进入项目目录,命令行输入

piinstaller -D -w main.py

找到生成的main.spec文件,按如下修改

# -*- mode: python ; coding: utf-8 -*-


a = Analysis(
    ['main.py'],
    pathex=[],
    binaries=[],
    datas=[('templates/','templates/'),('static/','static/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_shared.dll','onnxruntime/capi/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_cuda.dll','onnxruntime/capi/')],
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)

exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name='main',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    console=False,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
    icon='1.ico'
)
coll = COLLECT(
    exe,
    a.binaries,
    a.datas,
    strip=False,
    upx=True,
    upx_exclude=[],
    name='main',
)

在项目目录下放置一个图标命名为1.ico,最好是48*48像素

然后命令行运行

pyinstaller main.spec

然后在venv中找到 onnxruntime_gpu-1.18.1.dist-info 文件夹,复制到 dist/main/_internal 中

同时在cuDNN中找到如下几个动态链接库,复制到 dist/main/_internal 中

cudnn_ops_infer64_8.dll
cudnn_cnn_infer64_8.dll
cudnn_adv_infer64_8.dll
cudnn64_8.dll
cudart64_110.dll
cublasLt64_11.dll
cublas64_11.dll
cufft64_10.dll

然后将model.onnx放到 dist/main/ ,并在该目录创建一个目录temp

最后处理的结果如下

XXX
    dist
        main
            _internal
            main.exe
            model.onnx
            temp

生成安装包

使用into setup软件,并在网站找到中文的语言包下载为 Chinese.isl 文件,放到intosetup软件安装目录的 Languages 文件夹下

接着如图所示

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

取消立即编译,先进入文件里修改一些东西

修改成下面这样 

点击编译

然后就生成了安装包,就可以在任何win10,win11电脑里用CPU跑了,如果安装的电脑 有显卡和CUDA并把CUDA添加到了环境变量,就可以用GPU跑了

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值