之前的学习笔记是调用训练好的模型来做识别,分为加载本地图片识别和调用USB摄像头实时识别(IP摄像头暂时不可用);但是必须先进行训练,才能得到可供使用的模型文件。加之训练过程比较复杂,需要调用多个脚本,上手不便;因此制作这个训练用的软件,一方面方便自己使用,另一方面对自己也是个锻炼。软件最终的界面如下图所示,可以使用IP、USB摄像头将图像实时显示在界面中,并在界面中实时进行标记(类似于labelImg软件)。标记结束后,标准xml文件和原始图像保存在软件脚本所在的目录下:xml保存于annotatis文件夹,图像保存于img文件夹;同时自动划分训练集及验证集,并生成对应的tfrecord格式文件,这些数据一并保存在data文件夹下。在标记过程中亦可自动生成label map文件。
软件目前需加载预训练模型进行训练(重新训练按钮暂无作用),使用的是ssd_mobilenet_v2_coco预训练模型(API原始代码就带)。软件可以完成从图像采集----->标注图像----->生成所需数据----->进行训练----->生成pb模型文件的整个过程。
目前存在的问题是:希望实现自动修改config文件的功能,即由用户在界面中指定训练次数以及batch size,但该功能目前还不能实现。
软件需进一步优化。所有的源代码如下,本软件只是将object detection API的各个脚本封装在了一起。
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import sys
import cv2
import os
import random
import PIL.Image
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import functools
import json
from lxml import etree
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from xml.dom.minidom import Document
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.legacy import trainer
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2
class MyLabel(QLabel):
    """QLabel on which a rectangle can be dragged out with the left mouse button.

    The press position is kept in ``x0``/``y0`` and the latest drag position
    in ``x1``/``y1``. The rectangle spanned by those two points is painted in
    green on top of the label's regular content. The same coordinates are
    also written to the module-level globals ``x_start``, ``y_start``,
    ``x_end`` and ``y_end`` (presumably read by the code that saves the
    annotation -- confirm against the rest of the file).
    """

    # Press (anchor) position of the current selection.
    x0 = 0
    y0 = 0
    # Most recent drag position of the current selection.
    x1 = 0
    y1 = 0
    # True while a left-button drag is in progress.
    flag = False

    def mousePressEvent(self, event):
        """Start a new selection at the cursor position on a left-button press."""
        global x_start
        global y_start
        if event.buttons() == QtCore.Qt.LeftButton:
            self.flag = True
            self.x0 = event.x()
            self.y0 = event.y()
            x_start = self.x0
            y_start = self.y0

    def mouseReleaseEvent(self, event):
        """End the drag; the last recorded rectangle stays as it is."""
        self.flag = False

    def mouseMoveEvent(self, event):
        """Track the cursor while the left button is held and repaint."""
        global x_end
        global y_end
        if event.buttons() == QtCore.Qt.LeftButton and self.flag:
            self.x1 = event.x()
            self.y1 = event.y()
            self.update()
            x_end = self.x1
            y_end = self.y1

    def paintEvent(self, event):
        """Paint the label as usual, then draw the selection rectangle."""
        super().paintEvent(event)
        # Anchor the rectangle at the smaller coordinate on each axis so that
        # dragging up or to the left still outlines the region between the
        # press point and the cursor. Anchoring at (x0, y0) with an absolute
        # width/height would draw the box on the wrong side of the press point.
        left = min(self.x0, self.x1)
        top = min(self.y0, self.y1)
        width = abs(self.x1 - self.x0)
        height = abs(self.y1 - self.y0)
        rect = QRect(left, top, width, height)
        painter = QPainter(self)
        painter.setPen(QPen(Qt.green, 2, Qt.SolidLine))
        painter.drawRect(rect)
class Ui_train_window(QtWidgets.QWidget):
def setupUi(self):
self.setObjectName("train_window")
self.resize(690, 600)
self.setMinimumSize(QtCore.QSize(690, 600))
self.setMaximumSize(QtCore.QSize(690, 600))
self.horizontalLayoutWidget = QtWidgets.QWidget(self)
self.horizontalLayoutWidget.setGeometry(QtCore.QRect(0, 10, 681, 80))
self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
self.horizontalLayout.setObjectName("horizontalLayout")
self.lab_enter_name = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.lab_enter_name.setObjectName("lab_enter_name")
self.horizontalLayout.addWidget(self.lab_enter_name)
self.le_username = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_username.setObjectName("le_username")
self.horizontalLayout.addWidget(self.le_username)
self.lab_enter_pw = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.lab_enter_pw.setObjectName("lab_enter_pw")
self.horizontalLayout.addWidget(self.lab_enter_pw)
self.le_userpw = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_userpw.setObjectName("le_userpw")
self.le_userpw.setEchoMode(QLineEdit.Password)
self.horizontalLayout.addWidget(self.le_userpw)
self.label = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.label.setObjectName("label")
self.horizontalLayout.addWidget(self.label)
self.le_ipadr = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_ipadr.setObjectName("le_ipadr")
self.horizontalLayout.addWidget(self.le_ipadr)
self.btn_openIPcam = QtWidgets.QPushButton(self.horizontalLayoutWidget)
self.btn_openIPcam.setMinimumSize(QtCore.QSize(50, 10))
self.btn_openIPcam.setObjectName("btn_openIPcam")
self.horizontalLayout.addWidget(self.btn_openIPcam)
self.showpic= MyLabel(self)
self.showpic.setGeometry(QtCore.QRect(15, 110, 500, 400))
self.showpic.setMinimumSize(QtCore.QSize(500, 400))
self.showpic.setMaximumSize(QtCore.QSize(500, 400))
self.showpic.setObjectName("show")
self.showpic.setStyleSheet(("border:2px solid lightgray"))
self.verticalLayoutWidget = QtWidgets.QWidget(self)
self.verticalLayoutWidget.setGeometry(QtCore.QRect(560, 100, 111, 381))
self.verticalLayoutWidget.setObjectName("verticalLayoutWidget")
self.verticalLayout = QtWidgets.QVBoxLayout(self.verticalLayoutWidget)
self.verticalLayout.setContentsMargins(0, 0, 0, 0)
self.verticalLayout.setObjectName("verticalLayout")
self.btn_start = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_start.setObjectName("btn_start")
self.verticalLayout.addWidget(self.btn_start)
self.lab_herelable = QtWidgets.QLabel(self.verticalLayoutWidget)
self.lab_herelable.setObjectName("lab_herelable")
self.lab_herelable.setMaximumSize(QtCore.QSize(16777215, 15))
self.verticalLayout.addWidget(self.lab_herelable)
self.combo_label = QtWidgets.QComboBox(self.verticalLayoutWidget)
self.combo_label.setObjectName("combo_label")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.verticalLayout.addWidget(self.combo_label)
self.btn_save = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_save.setObjectName("btn_save")
self.verticalLayout.addWidget(self.btn_save)
self.btn_finish = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_finish.setObjectName("btn_finish")
self.verticalLayout.addWidget(self.btn_finish)
self.horizontalLay