import threading
import torch
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import QThread, pyqtSignal, QObject, Qt
from PyQt5.QtWidgets import QWidget, QApplication, QTableWidget, QTableWidgetItem, QLabel, QMessageBox
import sys
from ultralytics import YOLO
class YoloTrainingThread(QObject):
    """Worker object that runs an Ultralytics YOLO training job.

    Despite its name this is a QObject *worker*, not a QThread: the owner
    creates a QThread, calls ``moveToThread`` and connects
    ``QThread.started`` to :meth:`run`.

    Signals:
        update_loss(float, int): current training loss and epoch index,
            emitted after every training batch.
        log(str): informational messages (e.g. "stopped by user").
        finished(): emitted exactly once when :meth:`run` returns, whether
            training completed, failed or was stopped; connect it to
            ``QThread.quit``.
        error(str): text of the exception that aborted training.
    """

    update_loss = pyqtSignal(float, int)
    log = pyqtSignal(str)
    finished = pyqtSignal()
    error = pyqtSignal(str)

    # Dataset used when config['data'] is empty (the previously hard-coded path).
    DEFAULT_DATA = 'D:/DeepModel/ultralytics-cls/data-defects1'

    class StopRequested(Exception):
        """Raised from the batch callback to abort ``model.train()`` early."""

    def __init__(self, config, model_path="./yolov8s-cls.pt"):
        """Load the model and register the per-batch callback.

        Args:
            config: dict whose string values 'epochs', 'imgsz', 'batch' and
                optionally 'data' are read when :meth:`run` starts, so the
                caller may keep editing the same dict until then.
            model_path: weights file passed to ``YOLO``.
        """
        super().__init__()
        self.config = config
        # Public flag kept for compatibility: True once a stop was requested.
        self.stop_signal = False
        # Set from the GUI thread, polled from the training thread.
        self._stop_event = threading.Event()
        self.model = YOLO(model_path)
        # Register once here; registering inside run() would add a duplicate
        # callback (and duplicate signal emissions) on every restart.
        self.model.add_callback('on_train_batch_end', self.on_train_batch_end)

    def run(self):
        """Train the model. Never lets an exception escape.

        Runs in the worker thread as a slot. PyQt5 (>= 5.5) calls qFatal and
        aborts the whole application when an exception propagates out of a
        slot, so failures are reported via ``error`` instead, and ``finished``
        is always emitted so the owning QThread can be told to quit.
        """
        self._stop_event.clear()
        self.stop_signal = False
        try:
            self.model.train(
                data=self.config.get('data') or self.DEFAULT_DATA,
                epochs=int(self.config['epochs']),
                imgsz=int(self.config['imgsz']),
                batch=int(self.config['batch']),
                optimizer="SGD",
            )
        except self.StopRequested:
            self.log.emit("training stopped by user")
        except Exception as e:
            self.error.emit(str(e))
        finally:
            # The original called self.model.close(). YOLO defines no close()
            # method (verify on the installed ultralytics version), so the
            # AttributeError escaped this slot and aborted the application.
            torch.cuda.empty_cache()
            self.finished.emit()

    def on_train_batch_end(self, trainer):
        """Ultralytics callback invoked after each training batch.

        Runs in the worker thread. Forwards loss/epoch to the GUI and aborts
        training when :meth:`train_stop` has been called. Raising here
        propagates out of ``model.train()`` because ultralytics does not catch
        callback exceptions (confirm on the installed version).
        """
        if self._stop_event.is_set():
            raise self.StopRequested()
        loss = trainer.loss
        if loss is None:
            return
        # trainer.loss is presumably a torch tensor; convert explicitly so the
        # float-typed signal never receives a tensor object.
        if isinstance(loss, torch.Tensor):
            loss = loss.detach().sum().item()
        self.update_loss.emit(float(loss), int(trainer.epoch))

    def train_stop(self):
        """Request that training stop at the next batch boundary.

        Thread-safe. Call it directly, not through a queued signal: the
        worker thread is busy inside :meth:`run` and cannot process queued
        calls until training returns.
        """
        self.stop_signal = True
        self._stop_event.set()
class ui_Qwidget(QWidget):
    """Main window: parameter table, train/stop buttons and a loss read-out.

    Training runs in a YoloTrainingThread worker moved onto a QThread.

    The QThread is never quit, waited on or deleted while the worker is
    inside run(). Destroying a running QThread aborts the process with
    "QThread: Destroyed while thread is still running", which was the crash
    when ending training.

    Stopping only *requests* a stop. Cleanup happens once QThread.finished
    has fired.
    """

    # Row labels of the parameter table, in display order.
    PARAM_NAMES = ('data', 'epochs', 'imgsz', 'batch', 'train_ratio')

    def __init__(self):
        super().__init__()
        # Values typed into column 1 of the table, keyed by the column-0 name.
        # The worker reads this same dict when training starts.
        self.config = {'data': '',
                       'epochs': '',
                       'imgsz': '',
                       'batch': '',
                       }
        # Set when the user closes the window mid-training; the window then
        # closes itself once the worker thread has finished.
        self._close_requested = False
        self.worker = None
        # NOTE: this attribute shadows QObject.thread(); name kept for compatibility.
        self.thread = None
        self.setUI()
        self._create_worker()

    def _create_worker(self):
        """Create a fresh worker + QThread pair and wire all their signals."""
        self.worker = YoloTrainingThread(self.config)
        self.thread = QThread()
        self.worker.moveToThread(self.thread)
        self.thread.started.connect(self.worker.run)
        self.worker.update_loss.connect(self.update_loss_display)
        self.worker.error.connect(self._on_training_error)
        # run() executes inside the started() emission, before the thread's
        # event loop starts; quitting on finished makes that loop exit at once.
        self.worker.finished.connect(self.thread.quit)
        self.thread.finished.connect(self._on_training_finished)

    def setUI(self):
        """Build all widgets (absolute geometry) and connect their signals."""
        # --- Parameter table: column 0 = parameter name, column 1 = value ---
        self.tableWidget = QTableWidget(self)
        self.tableWidget.setGeometry(100, 80, 400, 300)
        self.tableWidget.setObjectName("tableWidget")
        self.tableWidget.setColumnCount(2)
        self.tableWidget.setRowCount(len(self.PARAM_NAMES))
        _translate = QtCore.QCoreApplication.translate
        for col, title in enumerate(("Name", "Param")):
            header = QTableWidgetItem()
            header.setTextAlignment(Qt.AlignCenter)
            header.setText(_translate("widget", title))
            self.tableWidget.setHorizontalHeaderItem(col, header)
        for row, name in enumerate(self.PARAM_NAMES):
            name_item = QTableWidgetItem(name)
            # Name cells are read-only: TabelConnect maps values by these labels.
            name_item.setFlags(name_item.flags() & ~Qt.ItemIsEditable)
            self.tableWidget.setItem(row, 0, name_item)
        self.tableWidget.itemChanged.connect(self.TabelConnect)
        self.tableWidget.horizontalHeader().resizeSection(0, 100)
        self.tableWidget.horizontalHeader().resizeSection(1, 270)

        # --- Buttons ---
        self.trainBtn1 = QtWidgets.QPushButton('选择文件夹', self)
        self.trainBtn1.setGeometry(100, 20, 100, 30)
        self.trainBtn2 = QtWidgets.QPushButton('新建文件夹', self)
        self.trainBtn2.setGeometry(230, 20, 100, 30)
        self.trainBtn3 = QtWidgets.QPushButton('训练', self)
        self.trainBtn3.setGeometry(1000, 80, 100, 30)
        self.trainBtn3.clicked.connect(self.trainStart)
        self.trainBtn4 = QtWidgets.QPushButton('推理', self)
        self.trainBtn4.setGeometry(1000, 140, 100, 30)
        self.trainBtn5 = QtWidgets.QPushButton('结束', self)
        self.trainBtn5.setGeometry(1000, 200, 100, 30)
        self.trainBtn5.clicked.connect(self.closeWindow)

        # --- Loss / epoch read-outs and training status ---
        self.lossLabelName = QLabel('loss:', self)
        self.lossLabelName.setGeometry(600, 80, 100, 30)
        self.lossLabel = QLabel(self)
        self.lossLabel.setGeometry(650, 80, 100, 30)
        self.epochLabelName = QLabel('epoch:', self)
        self.epochLabelName.setGeometry(600, 120, 100, 30)
        self.epochLabel = QLabel(self)
        self.epochLabel.setGeometry(650, 120, 100, 30)
        self.TrainingStatus = QLabel(self)
        self.TrainingStatus.setGeometry(1000, 20, 100, 30)

    def TabelConnect(self, item):
        """Copy an edited value cell (column 1) into ``self.config``.

        Edits in any other column are ignored. The original looked up column
        ``col - 1`` unconditionally. Editing a name cell therefore asked for
        column -1 and got None; ``.text()`` then raised inside this slot,
        which aborts a PyQt5 application.
        """
        if item.column() != 1:
            return
        name_item = self.tableWidget.item(item.row(), 0)
        if name_item is None:
            return
        name = name_item.text()
        if name in self.config:
            self.config[name] = item.text().strip()

    def trainStart(self):
        """Start training in the background thread (no-op if already running)."""
        if self.thread is not None and self.thread.isRunning():
            return
        if self.worker is None or self.thread is None:
            self._create_worker()
        self._close_requested = False
        self.trainBtn3.setEnabled(False)
        self.TrainingStatus.setText('训练中')
        self.thread.start()

    def update_loss_display(self, loss, epoch):
        """Slot for ``update_loss``: show the latest loss and epoch (GUI thread)."""
        self.lossLabel.setText(f"{loss:.3f}")
        self.epochLabel.setText(f"{epoch}")

    def closeWindow(self):
        """Handler for the '结束' button: ask a running training job to stop.

        Non-blocking. The thread is not quit, waited on or deleted here,
        because it is busy inside worker.run(). Cleanup happens in
        _on_training_finished once the worker has actually returned.
        """
        if self.thread is None or not self.thread.isRunning():
            return
        self.worker.train_stop()
        self.trainBtn5.setEnabled(False)
        self.TrainingStatus.setText('正在停止')

    def _on_training_error(self, message):
        """Slot for the worker's ``error`` signal: show it to the user."""
        QMessageBox.warning(self, '训练出错', message)

    def _on_training_finished(self):
        """Slot for ``QThread.finished``: reset the UI and release resources."""
        self.trainBtn3.setEnabled(True)
        self.trainBtn5.setEnabled(True)
        self.TrainingStatus.setText('训练结束')
        self.force_cleanup()
        if self._close_requested:
            self.close()

    def force_cleanup(self):
        """Drop the worker and its QThread, but only once the thread has stopped.

        Does nothing while the thread is running, because deleting a running
        QThread aborts the process. trainStart() recreates the pair on demand.
        """
        if self.thread is not None:
            if self.thread.isRunning():
                return
            # isRunning() may already be False while QThread is still
            # finishing (emitting finished()); wait() covers that short window
            # so the object is never destroyed mid-finish.
            self.thread.wait()
        self.worker = None
        self.thread = None

    def closeEvent(self, event):
        """Defer closing while training runs, so the QThread is never destroyed
        mid-run. A stop is requested and the window closes itself once
        training has stopped."""
        if self.thread is not None and self.thread.isRunning():
            self._close_requested = True
            self.closeWindow()
            event.ignore()
            return
        super().closeEvent(event)
if __name__ == '__main__':
    # Script entry point: build the Qt application and the main window.
    app = QApplication(sys.argv)
    window = ui_Qwidget()
    window.setWindowTitle("训练软件")
    window.resize(1200, 800)
    window.show()
    # Propagate the event loop's exit code to the calling shell.
    sys.exit(app.exec())
# 结束训练界面会崩溃是为什么？(Question: why does the UI crash when training is stopped?)