【Python学习】为深度学习的训练过程加一个桌面监视器

1. 背景

        在深度神经网络的学习过程中,是否碰到过训练过程太长,把训练程序放在后台,不知道训练进度,经常要打开查看的情况

2. 目的

        写一个小程序,注册为Tensorflow的Keras框架的fit函数的回调,然后再训练过程中自动调用显示窗口在左面左下角,显示step和epoch进度,当训练完成后,窗口还会闪烁提醒

        

3. 代码

import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QLabel, QProgressBar
from PyQt5.QtCore import Qt, QTimer
import PyQt5.QtWidgets as wid
from PyQt5.QtGui import QFont, QPalette, QColor, QPainter
from keras.callbacks import Callback
from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
from PyQt5.QtCore import QRect
import win32gui
import win32con


class set_window_foreground:
    def __init__(self, window: QMainWindow):
        self.hwnd_map = None
        self.window = window
        self.get_process_count = 0

    def get_all_hwnd(self, hwnd, mouse):
        """
        获取后台运行的所有程序
        """
        if (win32gui.IsWindow(hwnd) and
                win32gui.IsWindowEnabled(hwnd) and
                win32gui.IsWindowVisible(hwnd)):
            self.hwnd_map.update({hwnd: win32gui.GetWindowText(hwnd)})

    def list_window(self):

        try:
            self.hwnd_map = {}
            win32gui.EnumWindows(self.get_all_hwnd, 0)
        except Exception as exx:
            print("\nlll", exx)

    def get_center_pos(rect: QRect):
        x = rect.x()
        y = rect.y()
        center_x = int(x + rect.width() / 2)
        center_y = int(y + rect.height() / 2)
        return [center_x, center_y]

    def Window_top(self):
        """
        设置指定窗口置顶
        """
        try:
            if (self.get_process_count < 10):
                self.list_window()
                self.get_process_count += 1
            App_name = str(self.window.windowTitle())
            for h, t in self.hwnd_map.items():
                if t != "":
                    if t.find(App_name) != -1:
                        # h 为想要放到最前面的窗口句
                        win32gui.BringWindowToTop(h)
                        # shell = win32com.client.Dispatch("WScript.Shell")
                        # shell.SendKeys('%')
                        # 被其他窗口遮挡,调用后放到最前面
                        win32gui.SetWindowPos(h,
                                              win32con.HWND_TOPMOST,
                                              self.window.geometry().x(),
                                              self.window.geometry().y(),
                                              self.window.geometry().width(),
                                              self.window.geometry().height(),
                                              win32con.SWP_SHOWWINDOW | win32con.SWP_NOMOVE | win32con.SWP_NOSIZE)
                        # 解决被最小化的情况
                        #win32gui.ShowWindow(h, win32con.SW_RESTORE)

        except Exception as ee:
            print("\nWindow_top", ee)

    def cancel_Window_top(self):
        """
        取消指定窗口置顶
        """
        try:
            App_name = str(self.window.windowTitle())
            for h, t in self.hwnd_map.items():
                if t != "":
                    if t.find(App_name) != -1:
                        # h 为想要放到最前面的窗口句柄
                        print(h, t)

                        # win32gui.BringWindowToTop(h)
                        # shell = win32com.client.Dispatch("WScript.Shell")
                        # shell.SendKeys('%')
                        # 取消置顶
                        win32gui.SetWindowPos(h,
                                              win32con.HWND_NOTOPMOST,
                                              self.window.geometry().x(),
                                              self.window.geometry().y(),
                                              self.window.geometry().width(),
                                              self.window.geometry().height(),
                                              win32con.SWP_SHOWWINDOW)
                        # 解决被最小化的情况
                        win32gui.ShowWindow(h, win32con.SW_RESTORE)
        except Exception as ee:
            print("\ncancel_Window_top:", ee)


class TrainingProgressWindow(QMainWindow):

    def update_window(self):
        self.window_control.Window_top()

    def update_background(self):
        if (self.flash_count > 10):
            self.timer_blink.stop()
            self.close()
        self.blink_state = (self.blink_state + 1) % 2
        self.setStyleSheet("")
        self.repaint()
        self.flash_count += 1

    def paintEvent(self, event):
        painter = QPainter(self)
        if self.blink_state == 0:
            painter.setBrush(QColor(255, 255, 0))
        else:
            painter.setBrush(QColor(78, 79, 79))
        painter.drawRect(self.rect())
        
    def __init__(self):
        super().__init__()
        self.flash_count = 0
        self.blink_state = 0
        self.timer_blink = QTimer()
        self.timer_blink.timeout.connect(self.update_background)
        self.window_control = set_window_foreground(self)
        self.setWindowTitle("Training Progress")
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_window)
        self.time_out = 2000
        self.timer.start(self.time_out)
        screen = QDesktopWidget().screenGeometry()
        self.setGeometry(screen.width() - 300, screen.height() - 300, 300, 200)
        layout = QVBoxLayout()

        self.label_train = QLabel("Training Progress: 0%")
        font = QFont("Comic Sans MS", 15)
        self.label_train.setFont(font)
        self.label_train.setStyleSheet('color: white;')
        layout.addWidget(self.label_train)

        self.progress_bar_epoch = QProgressBar()
        self.progress_bar_epoch.setValue(0)
        self.progress_bar_epoch.setMaximum(100)
        self.progress_bar_epoch.setStyleSheet("QProgressBar {"
                                              "border: 2px solid white;"
                                              "border-radius: 10px;"
                                              "text-align: right;}"
                                              "QProgressBar::chunk {"
                                              "background-color: #05B8CC;"
                                              # "foreground-color: rgb(177, 177, 177);"
                                              "width: 10px;"
                                              "margin: 0.5px;}")

        # 获取QProgressBar对象的QPalette对象
        palette = self.progress_bar_epoch.palette()
        # 设置进度条控件的文本颜色为红色
        palette.setColor(QPalette.Text, QColor("white"))
        # 将修改后的QPalette对象应用到进度条控件上
        self.progress_bar_epoch.setPalette(palette)
        layout.addWidget(self.progress_bar_epoch)

        self.label_batch = QLabel("Step Progress: 0%")
        self.label_batch.setFont(font)
        self.label_batch.setStyleSheet('color: white;')
        layout.addWidget(self.label_batch)

        self.progress_bar_batch = QProgressBar()
        self.progress_bar_batch.setValue(0)
        self.progress_bar_batch.setMaximum(100)
        self.progress_bar_batch.setStyleSheet("QProgressBar {"
                                              "border: 2px solid white;"
                                              "border-radius: 10px;"
                                              "text-align: right;}"
                                              "QProgressBar::chunk {"
                                              "background-color: #05B8CC;"
                                              # "foreground-color: rgb(177, 177, 177);"
                                              "width: 10px;"
                                              "margin: 0.5px;}")
        # 获取QProgressBar对象的QPalette对象
        palette = self.progress_bar_batch.palette()
        # 设置进度条控件的文本颜色为红色
        palette.setColor(QPalette.Text, QColor("white"))
        # 将修改后的QPalette对象应用到进度条控件上
        self.progress_bar_batch.setPalette(palette)
        layout.addWidget(self.progress_bar_batch)

        central_widget = wid.QWidget()
        central_widget.setLayout(layout)
        self.setCentralWidget(central_widget)
        self.setWindowFlags(Qt.FramelessWindowHint)
        self.setStyleSheet(  # "border-radius: 10px; border: 2px solid black;"
            "background-color: rgb(89, 89, 89);")

        self.current_epoch = 0
        self.current_step = 0

    def update_train_info(self, epoch):
        self.label_train.setText("Training Progress {} :".format(epoch))
        self.current_epoch = epoch

    def update_batch_info(self, batch):
        self.label_batch.setText("Step Progress {} :".format(batch))
        self.current_step = batch

    def update_train_progress(self, progress):
        self.label_train.setText("Training Progress {} : ".format(
            self.current_epoch) + str(progress) + "%")
        self.progress_bar_epoch.setValue(int(progress))

    def update_batch_progress(self, progress):
        self.label_batch.setText("Step Progress {} : ".format(
            self.current_step) + str(progress) + "%")
        self.progress_bar_batch.setValue(int(progress))
        
    def update_train_end(self):
        self.blink_state = 0
        self.flash_count = 0
        self.timer_blink.start(500)  # 设置闪烁间隔为500毫秒
        

'''实验性,希望把训练进度显示在窗口上'''


class ShowProgreeCallback(Callback):
    def __init__(self):
        super().__init__()
        self.progress_window = TrainingProgressWindow()
        self.progress_window.show()

    def on_epoch_end(self, epoch, logs=None):
        self.progress_window.update_train_info(epoch + 1)
        progress = (epoch + 1) * 100 / self.params["epochs"]
        progress = round(progress, 2)
        self.progress_window.update_train_progress(progress)
        QApplication.processEvents()

    def on_train_batch_end(self, batch, logs=None):
        self.progress_window.update_batch_info(batch + 1)
        progress = (batch + 1) * 100 / int(self.params["steps"])
        progress = round(progress, 2)
        self.progress_window.update_batch_progress(progress)
        QApplication.processEvents()
        
    def on_train_end(self, logs=None):
        self.progress_window.update_train_end()
        QApplication.processEvents()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = TrainingProgressWindow()
    window.show()
    sys.exit(app.exec_())

4. 使用方法

model.compile(
        optimizer=get_optimizer(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']
    )
    
    custom_callback = progress_window.ShowProgreeCallback()
    epoch = 1

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epoch,
        callbacks=[custom_callback]
    )

详细解释

这段代码定义了一个名为TrainingProgressWindow的类,它是一个用于显示训练进度的窗口。这个窗口有两个进度条,一个用于显示训练的进度(epoch),另一个用于显示每个批次的进度(batch)。

当创建这个类的实例时,它会设置一些基本的窗口属性,如标题、大小和位置,并添加两个标签和一个进度条到窗口中。这些标签用于显示训练和批次的进度百分比,而进度条则用于显示实际的进度。

此外,这个类还定义了一些方法,如paintEventupdate_train_end等。其中,paintEvent方法是用于绘制窗口背景的,它会根据blink_state的值来设置画刷的颜色。update_train_end方法用于在训练结束时更新窗口的状态。

最后,这个类还定义了两个回调函数:on_epoch_endon_train_batch_end。这两个函数分别在每个epoch结束时和每个批次结束时被调用,它们会更新窗口中的进度信息,并触发窗口的重绘。

总的来说,这个类的主要作用是创建一个窗口,用于实时显示训练的进度。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值