[置顶] Java多线程编程模式实战指南(一):Active Object模式(上)

本文详细介绍了Active Object模式,一种提高并发性的异步编程模式。它通过将方法调用与执行分离,实现了任务提交和执行的解耦。文章通过具体的案例分析了Active Object模式的架构,包括Proxy、MethodRequest、ActivationQueue、Scheduler和Servant等组件,以及其实现并发控制和任务执行的过程。最后,文章以电信软件彩信短号模块为例,展示了如何在实际开发中应用Active Object模式实现请求缓存功能。

本文由黄文海首次发布在infoq中文站上:http://www.infoq.com/cn/articles/Java-multithreaded-programming-mode-active-object-part1 。转载请注明作者: 黄文海 出处:http://viscent.iteye.com。

 

Active Object模式简介

Active Object模式是一种异步编程模式。它通过对方法的调用与方法的执行进行解耦来提高并发性。若以任务的概念来说,Active Object模式的核心则是它允许任务的提交(相当于对异步方法的调用)和任务的执行(相当于异步方法的真正执行)分离。这有点类似于 System.gc()这个方法:客户端代码调用完gc()后,一个进行垃圾回收的任务被提交,但此时JVM并不一定进行了垃圾回收,而可能是在gc() 方法调用返回后的某段时间才开始执行任务——回收垃圾。我们知道,System.gc()的调用方代码是运行在自己的线程上(通常是main线程派生的子 线程),而JVM的垃圾回收这个动作则由专门的线程(垃圾回收线程)来执行的。换言之,System.gc()这个方法所代表的动作(其所定义的功能)的 调用方和执行方是运行在不同的线程中的,从而提高了并发性。

再进一步介绍Active Object模式,我们可先简单地将其核心理解为一个名为ActiveObject的类,该类对外暴露了一些异步方法,如图1所示。

图 1. ActiveObject对象示例

doSomething方法的调用方和执行方运行在各自的线程上。在并发的环境下,doSomething方法会被多个线程调用。这时所需的线程安 全控制封装在doSomething方法背后,使得调用方代码无需关心这点,从而简化了调用方代码:从调用方代码来看,调用一个Active Object对象的方法与调用普通Java对象的方法并无太大差别。如清单1所示。

清单 1. Active Object方法调用示例

ActiveObject ao=...;
Future future = ao.doSomething("data");
//执行其它操作
String result = future.get();
System.out.println(result);

Active Object模式的架构

当Active Object模式对外暴露的异步方法被调用时,与该方法调用相关的上下文信息,包括被调用的异步方法名(或其代表的操作)、调用方代码所传递的参数等,会 被封装成一个对象。该对象被称为方法请求(Method Request)。方法请求对象会被存入Active Object模式所维护的缓冲区(Activation Queue)中,并由专门的工作线程负责根据其包含的上下文信息执行相应的操作。也就是说,方法请求对象是由运行调用方代码的线程通过调用Active Object模式对外暴露的异步方法生成的,而方法请求所代表的操作则由专门的线程来执行,从而实现了方法的调用与执行的分离,产生了并发。

Active Object模式的主要参与者有以下几种。其类图如图2所示。

图 2. Active Object模式的类图

(点击图像放大)

  • Proxy:负责对外暴露异步方法接口。当调用方代码调用该参与者实例的异步方法doSomething时,该方法会生成一个相 应的MethodRequest实例并将其存储到Scheduler所维护的缓冲区中。doSomething方法的返回值是一个表示其执行结果的外包装 对象:Future参与者的实例。异步方法doSomething运行在调用方代码所在的线程中。
  • MethodRequest:负责将调用方代码对Proxy实例的异步方法的调用封装为一个对象。该对象保留了异步方法的名称及调用方代码传递的参数等上下文信息。它使得将Proxy的异步方法的调用和执行分离成为可能。其call方法会根据其所包含上下文信息调用Servant实例的相应方法。
  • ActivationQueue:负责临时存储由Proxy的异步方法被调用时所创建的MethodRequest实例的缓冲区。
  • Scheduler:负责将Proxy的异步方法所创建的MethodRequest实例存入其维护的缓冲区中。并根据一定的调 度策略,对其维护的缓冲区中的MethodRequest实例进行执行。其调度策略可以根据实际需要来定,如FIFO、LIFO和根据 MethodRequest中包含的信息所定的优先级等。
  • Servant:负责对Proxy所暴露的异步方法的具体实现。
  • Future:负责存储和返回Active Object异步方法的执行结果。

Active Object模式的序列图如图3所示。

图 3. Active Object模式的序列图

(点击图像放大)

第1步:调用方代码调用Proxy的异步方法doSomething。

第2~7步:doSomething方法创建Future实例作为该方法的返回值。并将调用方代码对该方法的调用封装为MethodRequest 对象。然后以所创建的MethodRequest对象作为参数调用Scheduler的enqueue方法,以将MethodRequest对象存入缓冲 区。Scheduler的enqueue方法会调用Scheduler所维护的ActivationQueue实例的enqueue方法,将 MethodRequest对象存入缓冲区。

第8步:doSomething返回其所创建的Future实例。

第9步:Scheduler实例采用专门的工作线程运行dispatch方法。

第10~12步:dispatch方法调用ActivationQueue实例的dequeue方法,获取一个MethodRequest对象。然后调用MethodRequest对象的call方法

第13~16步:MethodRequest对象的call方法调用与其关联的Servant实例的相应方法doSomething。并将Servant.doSomething方法的返回值设置到Future实例上。

第17步:MethodRequest对象的call方法返回。

上述步骤中,第1~8步是运行在Active Object的调用者线程中的,这几个步骤实现了将调用方代码对Active Object所提供的异步方法的调用封装成对象(Method Request),并将其存入缓冲区。这几个步骤实现了任务的提交。第9~17步是运行在Active Object的工作线程中,这些步骤实现从缓冲区中读取Method Request,并对其进行执行,实现了任务的执行。从而实现了Active Object对外暴露的异步方法的调用与执行的分离。

如果调用方代码关心Active Object的异步方法的返回值,则可以在其需要时,调用Future实例的get方法来获得异步方法的真正执行结果。

Active Object模式实战案例

某电信软件有一个彩信短号模块。其主要功能是实现手机用户给其它手机用户发送彩信时,接收方号码可以填写为对方的短号。例如,用户13612345678给其同事13787654321发送彩信时,可以将接收方号码填写为对方的短号,如776,而非其真实的号码。

该模块处理其接收到的下发彩信请求的一个关键操作是查询数据库以获得接收方短号对应的真实号码(长号)。该操作可能因为数据库故障而失败,从而使整 个请求无法继续被处理。而数据库故障是可恢复的故障,因此在短号转换为长号的过程中如果出现数据库异常,可以先将整个下发彩信请求消息缓存到磁盘中,等到 数据库恢复后,再从磁盘中读取请求消息,进行重试。为方便起见,我们可以通过Java的对象序列化API,将表示下发彩信的对象序列化到磁盘文件中从而实 现请求缓存。下面我们讨论这个请求缓存操作还需要考虑的其它因素,以及Active Object模式如何帮助我们满足这些考虑。

首先,请求消息缓存到磁盘中涉及文件I/O这种慢的操作,我们不希望它在请求处理的主线程(即Web服务器的工作线程)中执行。因为这样会使该模块 的响应延时增大,降低系统的响应性。并使得Web服务器的工作线程因等待文件I/O而降低了系统的吞吐量。这时,异步处理就派上用场了。Active Object模式可以帮助我们实现请求缓存这个任务的提交和执行分离:任务的提交是在Web服务器的工作线程中完成,而任务的执行(包括序列化对象到磁盘 文件中等操作)则是在Active Object工作线程中执行。这样,请求处理的主线程在侦测到短号转长号失败时即可以触发对当前彩信下发请求进行缓存,接着继续其请求处理,如给客户端响 应。而此时,当前请求消息可能正在被Active Object线程缓存到文件中。如图4所示。

图 4 .异步实现缓存

其次,每个短号转长号失败的彩信下发请求消息会被缓存为一个磁盘文件。但我们不希望这些缓存文件被存在同一个子目录下。而是希望多个缓存文件会被存 储到多个子目录中。每个子目录最多可以存储指定个数(如2000个)的缓存文件。若当前子目录已存满,则新建一个子目录存放新的缓存文件,直到该子目录也 存满,依此类推。当这些子目录的个数到达指定数量(如100个)时,最老的子目录(连同其下的缓存文件,如果有的话)会被删除。从而保证子目录的个数也是 固定的。显然,在并发环境下,实现这种控制需要一些并发访问控制(如通过锁来控制),但是我们不希望这种控制暴露给处理请求的其它代码。而Active Object模式中的Proxy参与者可以帮助我们封装并发访问控制。

下面,我们看该案例的相关代码通过应用Active Object模式在实现缓存功能时满足上述两个目标。首先看请求处理的入口类。该类就是本案例的Active Object模式的客调用方代码。如清单2所示。

清单 2. 彩信下发请求处理的入口类

public class MMSDeliveryServlet extends HttpServlet {

	private static final long serialVersionUID = 5886933373599895099L;

	@Override
	public void doPost(HttpServletRequest req, HttpServletResponse resp)
			throws ServletException, IOException {
		//将请求中的数据解析为内部对象
		MMSDeliverRequest mmsDeliverReq = this.parseRequest(req.getInputStream());
		Recipient shortNumberRecipient = mmsDeliverReq.getRecipient();
		Recipient originalNumberRecipient = null;

		try {
			// 将接收方短号转换为长号
			originalNumberRecipient = convertShortNumber(shortNumberRecipient);
		} catch (SQLException e) {

			// 接收方短号转换为长号时发生数据库异常,触发请求消息的缓存
			AsyncRequestPersistence.getInstance().store(mmsDeliverReq);

			// 继续对当前请求的其它处理,如给客户端响应
			resp.setStatus(202);
		}

	}

	private MMSDeliverRequest parseRequest(InputStream reqInputStream) {
		MMSDeliverRequest mmsDeliverReq = new MMSDeliverRequest();
		//省略其它代码
		return mmsDeliverReq;
	}

	private Recipient convertShortNumber(Recipient shortNumberRecipient)
			throws SQLException {
		Recipient recipent = null;
		//省略其它代码
		return recipent;
	}

}

清单2中的doPost方法在侦测到短号转换过程中发生的数据库异常后,通过调用AsyncRequestPersistence类的store方 法触发对彩信下发请求消息的缓存。这里,AsyncRequestPersistence类相当于Active Object模式中的Proxy参与者。尽管本案例涉及的是一个并发环境,但从清单2中的代码可见,AsyncRequestPersistence类的 调用方代码无需处理多线程同步问题。这是因为多线程同步问题被封装在AsyncRequestPersistence类之后。

AsyncRequestPersistence类的代码如清单3所示。

清单 3. 彩信下发请求缓存入口类(Active Object模式的Proxy)

// ActiveObjectPattern.Proxy
public class AsyncRequestPersistence implements RequestPersistence {
	private static final long ONE_MINUTE_IN_SECONDS = 60;
	private final Logger logger;
	private final AtomicLong taskTimeConsumedPerInterval = new AtomicLong(0);
	private final AtomicInteger requestSubmittedPerIterval = new AtomicInteger(0);
	
	// ActiveObjectPattern.Servant
	private final DiskbasedRequestPersistence 
						delegate = new DiskbasedRequestPersistence();
	// ActiveObjectPattern.Scheduler
	private final ThreadPoolExecutor scheduler;

	private static class InstanceHolder {
		final static RequestPersistence INSTANCE = new AsyncRequestPersistence();
	}

	private AsyncRequestPersistence() {
		logger = Logger.getLogger(AsyncRequestPersistence.class);
		scheduler = new ThreadPoolExecutor(1, 3, 
				60 * ONE_MINUTE_IN_SECONDS,
				TimeUnit.SECONDS,
				// ActiveObjectPattern.ActivationQueue
				new LinkedBlockingQueue(200), 
				new ThreadFactory() {
					@Override
					public Thread newThread(Runnable r) {
						Thread t;
						t = new Thread(r, "AsyncRequestPersistence");
						return t;
					}

				});

		scheduler.setRejectedExecutionHandler(
				new ThreadPoolExecutor.DiscardOldestPolicy());

		// 启动队列监控定时任务
		Timer monitorTimer = new Timer(true);
		monitorTimer.scheduleAtFixedRate(
		    new TimerTask() {

			@Override
			public void run() {
				if (logger.isInfoEnabled()) {

					logger.info("task count:" 
							+ requestSubmittedPerIterval
							+ ",Queue size:" 
							+ scheduler.getQueue().size()
							+ ",taskTimeConsumedPerInterval:"
							+ taskTimeConsumedPerInterval.get() 
							+ " ms");
				}

				taskTimeConsumedPerInterval.set(0);
				requestSubmittedPerIterval.set(0);

			}
		}, 0, ONE_MINUTE_IN_SECONDS * 1000);
	}

	public static RequestPersistence getInstance() {
		return InstanceHolder.INSTANCE;
	}

	@Override
	public void store(final MMSDeliverRequest request) {
		/*
		 * 将对store方法的调用封装成MethodRequest对象, 并存入缓冲区。
		 */
		// ActiveObjectPattern.MethodRequest
		Callable methodRequest = new Callable() {
			@Override
			public Boolean call() throws Exception {
				long start = System.currentTimeMillis();
				try {
					delegate.store(request);
				} finally {
					taskTimeConsumedPerInterval.addAndGet(
							System.currentTimeMillis() - start);
				}

				return Boolean.TRUE;
			}

		};
		scheduler.submit(methodRequest);

		requestSubmittedPerIterval.incrementAndGet();
	}

}

AsyncRequestPersistence类所实现的接口RequestPersistence定义了Active Object对外暴露的异步方法:store方法。由于本案例不关心请求缓存的结果,故该方法没有返回值。其代码如清单4所示。

清单 4. RequestPersistence接口源码

public interface RequestPersistence {

	 void store(MMSDeliverRequest request);
}

AsyncRequestPersistence类的实例变量scheduler相当于Active Object模式中的Scheduler参与者实例。这里我们直接使用了JDK1.5引入的Executor Framework中的ThreadPoolExecutor。在ThreadPoolExecutor类的实例化时,其构造器的第5个参数 (BlockingQueue<Runnable> workQueue)我们指定了一个有界阻塞队列:new LinkedBlockingQueue<Runnable>(200)。该队列相当于Active Object模式中的ActivationQueue参与者实例。

AsyncRequestPersistence类的实例变量delegate相当于Active Object模式中的Servant参与者实例。

AsyncRequestPersistence类的store方法利用匿名类生成一个 java.util.concurrent.Callable实例methodRequest。该实例相当于Active Object模式中的MethodRequest参与者实例。利用闭包(Closure),该实例封装了对store方法调用的上下文信息(包括调用参 数、所调用的方法对应的操作信息)。AsyncRequestPersistence类的store方法通过调用scheduler的submit方法, 将methodRequest送入ThreadPoolExecutor所维护的缓冲区(阻塞队列)中。确切地说,ThreadPoolExecutor 是Scheduler参与者的一个“近似”实现。ThreadPoolExecutor的submit方法相对于Scheduler的enqueue方 法,该方法用于接纳MethodRequest对象,以将其存入缓冲区。当ThreadPoolExecutor当前使用的线程数量小于其核心线程数量 时,submit方法所接收的任务会直接被新建的线程执行。当ThreadPoolExecutor当前使用的线程数量大于其核心线程数时,submit 方法所接收的任务才会被存入其维护的阻塞队列中。不过,ThreadPoolExecutor的这种任务处理机制,并不妨碍我们将它用作 Scheduler的实现。

methodRequest的call方法会调用delegate的store方法来真正实现请求缓存功能。delegate实例对应的类DiskbasedRequestPersistence是请求消息缓存功能的真正实现者。其代码如清单5所示。

清单 5. DiskbasedRequestPersistence类的源码

public class DiskbasedRequestPersistence implements RequestPersistence {
	// 负责缓存文件的存储管理
	private final SectionBasedDiskStorage storage = new SectionBasedDiskStorage();
	private final Logger logger = Logger
	                                 .getLogger(DiskbasedRequestPersistence.class);

	@Override
	public void store(MMSDeliverRequest request) {
		// 申请缓存文件的文件名
		String[] fileNameParts = storage.apply4Filename(request);
		File file = new File(fileNameParts[0]);
		try {
			ObjectOutputStream objOut = new ObjectOutputStream(
			new FileOutputStream(file));
			try {
				objOut.writeObject(request);
			} finally {
				objOut.close();
			}
		} catch (FileNotFoundException e) {
			storage.decrementSectionFileCount(fileNameParts[1]);
			logger.error("Failed to store request", e);
		} catch (IOException e) {
			storage.decrementSectionFileCount(fileNameParts[1]);
			logger.error("Failed to store request", e);
		}

	}

	class SectionBasedDiskStorage {
		private Deque sectionNames = new LinkedList();
		/*
		 * Key->value: 存储子目录名->子目录下缓存文件计数器
		 */
		private Map sectionFileCountMap 
						= new HashMap();
		private int maxFilesPerSection = 2000;
		private int maxSectionCount = 100;
		private String storageBaseDir = System.getProperty("user.dir") + "/vpn";

		private final Object sectionLock = new Object();

		public String[] apply4Filename(MMSDeliverRequest request) {
			String sectionName;
			int iFileCount;
			boolean need2RemoveSection = false;
			String[] fileName = new String[2];
			synchronized (sectionLock) {
				//获取当前的存储子目录名
				sectionName = this.getSectionName();
				AtomicInteger fileCount;
				fileCount = sectionFileCountMap.get(sectionName);
				iFileCount = fileCount.get();
				//当前存储子目录已满
				if (iFileCount >= maxFilesPerSection) {
					if (sectionNames.size() >= maxSectionCount) {
						need2RemoveSection = true;
					}
					//创建新的存储子目录
					sectionName = this.makeNewSectionDir();
					fileCount = sectionFileCountMap.get(sectionName);

				}
				iFileCount = fileCount.addAndGet(1);

			}

			fileName[0] = storageBaseDir + "/" + sectionName + "/"
			    + new DecimalFormat("0000").format(iFileCount) + "-"
			    + request.getTimeStamp().getTime() / 1000 + "-" 
                               + request.getExpiry()
			    + ".rq";
			fileName[1] = sectionName;

			if (need2RemoveSection) {
				//删除最老的存储子目录
				String oldestSectionName = sectionNames.removeFirst();
				this.removeSection(oldestSectionName);
			}

			return fileName;
		}

		public void decrementSectionFileCount(String sectionName) {
			AtomicInteger fileCount = sectionFileCountMap.get(sectionName);
			if (null != fileCount) {
				fileCount.decrementAndGet();
			}
		}

		private boolean removeSection(String sectionName) {
			boolean result = true;
			File dir = new File(storageBaseDir + "/" + sectionName);
			for (File file : dir.listFiles()) {
				result = result && file.delete();
			}
			result = result && dir.delete();
			return result;
		}

		private String getSectionName() {
			String sectionName;

			if (sectionNames.isEmpty()) {
				sectionName = this.makeNewSectionDir();

			} else {
				sectionName = sectionNames.getLast();
			}

			return sectionName;
		}

		private String makeNewSectionDir() {
			String sectionName;
			SimpleDateFormat sdf = new SimpleDateFormat("MMddHHmmss");
			sectionName = sdf.format(new Date());
			File dir = new File(storageBaseDir + "/" + sectionName);
			if (dir.mkdir()) {
				sectionNames.addLast(sectionName);
				sectionFileCountMap.put(sectionName, new AtomicInteger(0));
			} else {
				throw new RuntimeException(
                                 "Cannot create section dir " + sectionName);
			}

			return sectionName;
		}
	}
}

methodRequest的call方法的调用者代码是运行在ThreadPoolExecutor所维护的工作者线程中,这就保证了store 方法的调用方和真正的执行方是分别运行在不同的线程中:服务器工作线程负责触发请求消息缓存,ThreadPoolExecutor所维护的工作线程负责 将请求消息序列化到磁盘文件中。

DiskbasedRequestPersistence类的store方法中调用的SectionBasedDiskStorage类的 apply4Filename方法包含了一些多线程同步控制代码(见清单5)。这部分控制由于是封装在 DiskbasedRequestPersistence的内部类中,对于该类之外的代码是不可见的。因 此,AsyncRequestPersistence的调用方代码无法知道该细节,这体现了Active Object模式对并发访问控制的封装。

小结

本篇介绍了Active Object模式的意图及架构,并以一个实际的案例展示了该模式的代码实现。下篇将对Active Object模式进行评价,并结合本文案例介绍实际运用Active Object模式时需要注意的一些事项。


感谢张龙对本文的审校。

import sys import cv2 import time import torch import traceback import threading import queue import dxcam import ctypes import os import glob import numpy as np import logitech.lg from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QComboBox, QSlider, QSpinBox, QDoubleSpinBox, QLineEdit, QTabWidget, QGroupBox, QTextEdit, QFileDialog, QMessageBox, QSizePolicy, QSplitter, QDialog, QScrollArea) from PyQt6.QtCore import Qt, QTimer, QThread, pyqtSignal from PyQt6.QtGui import QImage, QPixmap, QPainter, QColor, QFont, QIcon, QKeyEvent, QMouseEvent from PyQt6.QtSvg import QSvgRenderer from PIL import Image from ultralytics import YOLO from pynput import mouse class PIDController: """PID控制器""" def __init__(self, kp, ki, kd, output_min=-100, output_max=100): self.kp = kp # 比例增益 self.ki = ki # 积分增益 self.kd = kd # 微分增益 self.output_min = output_min self.output_max = output_max # 状态变量 self.integral = 0.0 self.prev_error = 0.0 self.last_time = time.perf_counter() def compute(self, setpoint, current_value): """计算PID控制输出""" current_time = time.perf_counter() dt = current_time - self.last_time # 防止过小的时间差导致计算问题 MIN_DT = 0.0001 if dt < MIN_DT: dt = MIN_DT # 计算误差 error = setpoint - current_value # 比例项 P = self.kp * error # 积分项(防饱和) self.integral += error * dt I = self.ki * self.integral # 微分项 derivative = (error - self.prev_error) / dt D = self.kd * derivative # 合成输出 output = P + I + D # 输出限幅 if output > self.output_max: output = self.output_max elif output < self.output_min: output = self.output_min # 更新状态 self.prev_error = error self.last_time = current_time return output def reset(self): """重置控制器状态""" self.integral = 0.0 self.prev_error = 0.0 self.last_time = time.perf_counter() class ScreenDetector: def __init__(self, config_path): # 解析配置文件 self._parse_config(config_path) # 设备检测与模型加载 self.device = self._determine_device() self.model = YOLO(self.model_path).to(self.device) # 屏幕信息初始化 self._init_screen_info() # 控制参数初始化 self._init_control_params() # 状态管理 self.stop_event = threading.Event() self.camera_lock = threading.Lock() self.target_lock = threading.Lock() self.offset_lock = threading.Lock() self.button_lock = threading.Lock() # 推理状态控制 self.inference_active = False self.inference_lock = threading.Lock() # 初始化相机 self._init_camera() # 初始化鼠标监听器 self._init_mouse_listener() # 初始化PID控制器 self._init_pid_controllers() def _parse_config(self, config_path): """解析并存储配置参数""" self.cfg = self._parse_txt_config(config_path) # 存储常用参数 self.model_path = self.cfg['model_path'] self.model_device = self.cfg['model_device'] self.screen_target_size = int(self.cfg['screen_target_size']) self.detection_conf_thres = float(self.cfg['detection_conf_thres']) self.detection_iou_thres = float(self.cfg['detection_iou_thres']) self.detection_classes = [int(x) for x in self.cfg['detection_classes'].split(',')] self.visualization_color = tuple(map(int, self.cfg['visualization_color'].split(','))) self.visualization_line_width = int(self.cfg['visualization_line_width']) self.visualization_font_scale = float(self.cfg['visualization_font_scale']) self.visualization_show_conf = bool(self.cfg['visualization_show_conf']) self.fov_horizontal = float(self.cfg.get('move_fov_horizontal', '90')) self.mouse_dpi = int(self.cfg.get('move_mouse_dpi', '400')) self.target_offset_x_percent = float(self.cfg.get('target_offset_x', '50')) self.target_offset_y_percent = 100 - float(self.cfg.get('target_offset_y', '50')) # PID参数 self.pid_kp = float(self.cfg.get('pid_kp', '1.0')) self.pid_ki = float(self.cfg.get('pid_ki', '0.05')) self.pid_kd = float(self.cfg.get('pid_kd', '0.2')) # 贝塞尔曲线参数 self.bezier_steps = int(self.cfg.get('bezier_steps', '100')) self.bezier_duration = float(self.cfg.get('bezier_duration', '0.1')) self.bezier_curve = float(self.cfg.get('bezier_curve', '0.3')) def update_config(self, config_path): """动态更新配置""" try: # 重新解析配置文件 self._parse_config(config_path) # 更新可以直接修改的参数 self.detection_conf_thres = float(self.cfg['detection_conf_thres']) self.detection_iou_thres = float(self.cfg['detection_iou_thres']) self.target_offset_x_percent = float(self.cfg.get('target_offset_x', '50')) self.target_offset_y_percent = 100 - float(self.cfg.get('target_offset_y', '50')) # PID参数更新 self.pid_kp = float(self.cfg.get('pid_kp', '1.0')) self.pid_ki = float(self.cfg.get('pid_ki', '0.05')) self.pid_kd = float(self.cfg.get('pid_kd', '0.2')) # 更新PID控制器 self.pid_x = PIDController(self.pid_kp, self.pid_ki, self.pid_kd) self.pid_y = PIDController(self.pid_kp, self.pid_ki, self.pid_kd) # FOV和DPI更新 self.fov_horizontal = float(self.cfg.get('move_fov_horizontal', '90')) self.mouse_dpi = int(self.cfg.get('move_mouse_dpi', '400')) # 更新贝塞尔曲线参数 self.bezier_steps = int(self.cfg.get('bezier_steps', '100')) self.bezier_duration = float(self.cfg.get('bezier_duration', '0.1')) self.bezier_curve = float(self.cfg.get('bezier_curve', '0.3')) print("配置已动态更新") return True except Exception as e: print(f"更新配置失败: {str(e)}") traceback.print_exc() return False def _parse_txt_config(self, path): """解析TXT格式的配置文件""" config = {} with open(path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line or line.startswith('#'): continue if '=' in line: key, value = line.split('=', 1) config[key.strip()] = value.strip() return config def _init_pid_controllers(self): """初始化PID控制器""" # 创建XY方向的PID控制器 self.pid_x = PIDController(self.pid_kp, self.pid_ki, self.pid_kd) self.pid_y = PIDController(self.pid_kp, self.pid_ki, self.pid_kd) def start_inference(self): """启动推理""" with self.inference_lock: self.inference_active = True def stop_inference(self): """停止推理""" with self.inference_lock: self.inference_active = False def _determine_device(self): """确定运行设备""" if self.model_device == 'auto': return 'cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu' return self.model_device def _init_screen_info(self): """初始化屏幕信息""" user32 = ctypes.windll.user32 self.screen_width, self.screen_height = user32.GetSystemMetrics(0), user32.GetSystemMetrics(1) self.screen_center = (self.screen_width // 2, self.screen_height // 2) # 计算截图区域 left = (self.screen_width - self.screen_target_size) // 2 top = (self.screen_height - self.screen_target_size) // 2 self.region = ( max(0, int(left)), max(0, int(top)), min(self.screen_width, int(left + self.screen_target_size)), min(self.screen_height, int(top + self.screen_target_size)) ) def _init_control_params(self): """初始化控制参数""" self.previous_target_info = None self.closest_target_absolute = None self.target_offset = None self.right_button_pressed = False # 改为鼠标右键状态 def _init_camera(self): """初始化相机""" try: with self.camera_lock: self.camera = dxcam.create( output_idx=0, output_color="BGR", region=self.region ) self.camera.start(target_fps=120, video_mode=True) except Exception as e: print(f"相机初始化失败: {str(e)}") try: # 降级模式 with self.camera_lock: self.camera = dxcam.create() self.camera.start(target_fps=60, video_mode=True) except Exception as fallback_e: print(f"降级模式初始化失败: {str(fallback_e)}") self.camera = None def _init_mouse_listener(self): """初始化鼠标监听器""" self.mouse_listener = mouse.Listener( on_click=self.on_mouse_click # 监听鼠标点击事件 ) self.mouse_listener.daemon = True self.mouse_listener.start() def on_mouse_click(self, x, y, button, pressed): """处理鼠标点击事件""" try: if button == mouse.Button.right: # 监听鼠标右键 with self.button_lock: self.right_button_pressed = pressed # 更新状态 # 当右键释放时重置PID控制器 if not pressed: self.pid_x.reset() self.pid_y.reset() except Exception as e: print(f"鼠标事件处理错误: {str(e)}") def calculate_fov_movement(self, dx, dy): """基于FOV算法计算鼠标移动量""" # 计算屏幕对角线长度 screen_diagonal = (self.screen_width ** 2 + self.screen_height ** 2) ** 0.5 # 计算垂直FOV aspect_ratio = self.screen_width / self.screen_height fov_vertical = self.fov_horizontal / aspect_ratio # 计算每像素对应角度 angle_per_pixel_x = self.fov_horizontal / self.screen_width angle_per_pixel_y = fov_vertical / self.screen_height # 计算角度偏移 angle_offset_x = dx * angle_per_pixel_x angle_offset_y = dy * angle_per_pixel_y # 转换为鼠标移动量 move_x = (angle_offset_x / 360) * self.mouse_dpi move_y = (angle_offset_y / 360) * self.mouse_dpi return move_x, move_y def move_mouse_to_target(self): """移动鼠标对准目标点""" if not self.target_offset: return try: # 获取目标点与屏幕中心的偏移量 with self.offset_lock: dx, dy = self.target_offset # 使用FOV算法将像素偏移转换为鼠标移动量 move_x, move_y = self.calculate_fov_movement(dx, dy) # 使用PID计算平滑的移动量 pid_move_x = self.pid_x.compute(0, -move_x) # 将dx取反 pid_move_y = self.pid_y.compute(0, -move_y) # 将dy取反 # 移动鼠标 if pid_move_x != 0 or pid_move_y != 0: logitech.lg.start_mouse_move(int(pid_move_x), int(pid_move_y), self.bezier_steps, self.bezier_duration, self.bezier_curve) except Exception as e: print(f"移动鼠标时出错: {str(e)}") def run(self, frame_queue): """主检测循环""" while not self.stop_event.is_set(): try: # 检查推理状态 with self.inference_lock: if not self.inference_active: time.sleep(0.01) continue # 截图 grab_start = time.perf_counter() screenshot = self._grab_screenshot() grab_time = (time.perf_counter() - grab_start) * 1000 # ms if screenshot is None: time.sleep(0.001) continue # 推理 inference_start = time.perf_counter() results = self._inference(screenshot) inference_time = (time.perf_counter() - inference_start) * 1000 # ms # 处理检测结果 target_info, closest_target_relative, closest_offset = self._process_detection_results(results) # 更新目标信息 self._update_target_info(target_info, closest_offset) # 移动鼠标 self._move_mouse_if_needed() # 可视化处理 annotated_frame = self._visualize_results(results, closest_target_relative) if frame_queue else None # 放入队列 if frame_queue: try: frame_queue.put( (annotated_frame, len(target_info), inference_time, grab_time, target_info), timeout=0.01 ) except queue.Full: pass except Exception as e: print(f"检测循环异常: {str(e)}") traceback.print_exc() self._reset_camera() time.sleep(0.5) def _grab_screenshot(self): """安全获取截图""" with self.camera_lock: if self.camera: return self.camera.grab() return None def _inference(self, screenshot): """执行模型推理""" return self.model.predict( screenshot, conf=self.detection_conf_thres, iou=self.detection_iou_thres, classes=self.detection_classes, device=self.device, verbose=False ) def _process_detection_results(self, results): """处理检测结果""" target_info = [] min_distance = float('inf') closest_target_relative = None closest_target_absolute = None closest_offset = None for box in results[0].boxes: # 获取边界框坐标 x1, y1, x2, y2 = map(int, box.xyxy[0]) # 计算绝对坐标 x1_abs = x1 + self.region[0] y1_abs = y1 + self.region[1] x2_abs = x2 + self.region[0] y2_abs = y2 + self.region[1] # 计算边界框尺寸 width = x2_abs - x1_abs height = y2_abs - y1_abs # 应用偏移百分比计算目标点 target_x = x1_abs + int(width * (self.target_offset_x_percent / 100)) target_y = y1_abs + int(height * (self.target_offset_y_percent / 100)) # 计算偏移量 dx = target_x - self.screen_center[0] dy = target_y - self.screen_center[1] distance = (dx ** 2 + dy ** 2) ** 0.5 # 更新最近目标 if distance < min_distance: min_distance = distance # 计算相对坐标(用于可视化) closest_target_relative = ( x1 + int(width * (self.target_offset_x_percent / 100)), y1 + int(height * (self.target_offset_y_percent / 100)) ) closest_target_absolute = (target_x, target_y) closest_offset = (dx, dy) # 保存目标信息 class_id = int(box.cls) class_name = self.model.names[class_id] target_info.append(f"{class_name}:{x1_abs},{y1_abs},{x2_abs},{y2_abs}") return target_info, closest_target_relative, closest_offset def _update_target_info(self, target_info, closest_offset): """更新目标信息""" # 检查目标信息是否有变化 if target_info != self.previous_target_info: self.previous_target_info = target_info.copy() print(f"{len(target_info)}|{'|'.join(target_info)}") # 更新目标偏移量 with self.offset_lock: self.target_offset = closest_offset def _visualize_results(self, results, closest_target): """可视化处理结果""" frame = results[0].plot( line_width=self.visualization_line_width, font_size=self.visualization_font_scale, conf=self.visualization_show_conf ) # 绘制最近目标 if closest_target: # 绘制目标中心点 cv2.circle( frame, (int(closest_target[0]), int(closest_target[1])), 3, (0, 0, 255), -1 ) # 计算屏幕中心在截图区域内的相对坐标 screen_center_x = self.screen_center[0] - self.region[0] screen_center_y = self.screen_center[1] - self.region[1] # 绘制中心到目标的连线 cv2.line( frame, (int(screen_center_x), int(screen_center_y)), (int(closest_target[0]), int(closest_target[1])), (0, 255, 0), 1 ) return frame def _move_mouse_if_needed(self): """如果需要则移动鼠标""" with self.button_lock: if self.right_button_pressed and self.target_offset: # 使用right_button_pressed self.move_mouse_to_target() def _reset_camera(self): """重置相机""" print("正在重置相机...") try: self._init_camera() except Exception as e: print(f"相机重置失败: {str(e)}") traceback.print_exc() def stop(self): """安全停止检测器""" self.stop_event.set() self._safe_stop() if hasattr(self, 'mouse_listener') and self.mouse_listener.running: # 改为停止鼠标监听器 self.mouse_listener.stop() def _safe_stop(self): """同步释放资源""" print("正在安全停止相机...") try: with self.camera_lock: if self.camera: self.camera.stop() print("相机已停止") except Exception as e: print(f"停止相机时发生错误: {str(e)}") print("屏幕检测器已停止") class DetectionThread(QThread): update_signal = pyqtSignal(object) def __init__(self, detector, frame_queue): super().__init__() self.detector = detector self.frame_queue = frame_queue self.running = True def run(self): self.detector.run(self.frame_queue) def stop(self): self.running = False self.detector.stop() class MainWindow(QMainWindow): def __init__(self, detector): super().__init__() self.detector = detector self.setWindowTitle("EFAI 1.1") self.setGeometry(100, 100, 600, 400) # 添加缺失的属性初始化 self.visualization_enabled = True self.inference_active = False # 初始推理状态为停止 #窗口置顶 self.setWindowFlag(Qt.WindowType.WindowStaysOnTopHint) # 创建帧队列 self.frame_queue = queue.Queue(maxsize=3) # 初始化UI self.init_ui() # 启动检测线程 self.detection_thread = DetectionThread(self.detector, self.frame_queue) self.detection_thread.start() # 启动UI更新定时器 self.update_timer = QTimer() self.update_timer.timeout.connect(self.update_ui) self.update_timer.start(1) # 每1ms更新次 def toggle_visualization(self): # 实际更新可视化状态属性 self.visualization_enabled = not self.visualization_enabled # 更新按钮文本 if self.visualization_enabled: self.toggle_visualization_btn.setText("禁用可视化") else: self.toggle_visualization_btn.setText("启用可视化") def toggle_inference(self): """切换推理状态""" self.inference_active = not self.inference_active if self.inference_active: self.toggle_inference_btn.setText("停止推理") self.toggle_inference_btn.setStyleSheet(""" QPushButton { background-color: #F44336; color: white; border: none; padding: 8px; border-radius: 4px; font-family: Segoe UI; font-size: 10pt; } """) self.detector.start_inference() else: self.toggle_inference_btn.setText("开始推理") self.toggle_inference_btn.setStyleSheet(""" QPushButton { background-color: #4CAF50; color: white; border: none; padding: 8px; border-radius: 4px; font-family: Segoe UI; font-size: 10pt; } """) self.detector.stop_inference() def init_ui(self): # 主布局 central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QVBoxLayout(central_widget) # 分割器(左侧图像/目标信息,右侧控制面板) splitter = QSplitter(Qt.Orientation.Horizontal) main_layout.addWidget(splitter) # 左侧区域(图像显示和目标信息) left_widget = QWidget() left_layout = QVBoxLayout(left_widget) # 图像显示区域 self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.image_label.setMinimumSize(320, 320) left_layout.addWidget(self.image_label) # 目标信息区域 self.target_info_text = QTextEdit() self.target_info_text.setReadOnly(True) self.target_info_text.setFixedHeight(150) self.target_info_text.setStyleSheet(""" QTextEdit { background-color: #2D2D30; color: #DCDCDC; font-family: Consolas; font-size: 10pt; border: 1px solid #3F3F46; border-radius: 4px; } """) left_layout.addWidget(self.target_info_text) # 右侧控制面板 right_widget = QWidget() right_layout = QVBoxLayout(right_widget) right_layout.setAlignment(Qt.AlignmentFlag.AlignTop) # 性能信息 perf_group = QGroupBox("性能信息") perf_layout = QVBoxLayout(perf_group) self.target_count_label = QLabel("目标数量: 0") self.inference_time_label = QLabel("推理时间: 0.000s") self.grab_time_label = QLabel("截图时间: 0.000s") for label in [self.target_count_label, self.inference_time_label, self.grab_time_label]: label.setStyleSheet("font-family: Consolas; font-size: 10pt;") perf_layout.addWidget(label) right_layout.addWidget(perf_group) # 系统信息 sys_group = QGroupBox("系统信息") sys_layout = QVBoxLayout(sys_group) # 获取模型名称(只显示文件名) model_name = os.path.basename(self.detector.model_path) # 获取显示器编号(如果配置中有则显示,否则显示默认值0) monitor_index = self.detector.cfg.get('screen_monitor', '0') self.model_label = QLabel(f"模型: {model_name}") self.device_label = QLabel(f"设备: {self.detector.device.upper()}") self.monitor_label = QLabel(f"显示器:{monitor_index}") self.screen_res_label = QLabel(f"屏幕分辨率: {self.detector.screen_width}x{self.detector.screen_height}") self.region_label = QLabel(f"检测区域: {self.detector.region}") for label in [self.model_label, self.device_label, self.monitor_label, self.screen_res_label, self.region_label]: label.setStyleSheet("font-family: Consolas; font-size: 9pt; color: #A0A0A0;") sys_layout.addWidget(label) right_layout.addWidget(sys_group) # 鼠标状态 mouse_group = QGroupBox("自瞄状态") mouse_layout = QVBoxLayout(mouse_group) self.mouse_status = QLabel("未瞄准") self.mouse_status.setStyleSheet(""" QLabel { font-family: Consolas; font-size: 10pt; color: #FF5252; } """) mouse_layout.addWidget(self.mouse_status) right_layout.addWidget(mouse_group) # 控制按钮 btn_group = QGroupBox("控制") btn_layout = QVBoxLayout(btn_group) # 添加推理切换按钮 self.toggle_inference_btn = QPushButton("开始推理") self.toggle_inference_btn.clicked.connect(self.toggle_inference) self.toggle_inference_btn.setStyleSheet(""" QPushButton { background-color: #4CAF50; color: white; border: none; padding: 8px; border-radius: 4px; font-family: Segoe UI; font-size: 10pt; } QPushButton:hover { background-color: #45A049; } QPushButton:pressed { background-color: #3D8B40; } """) btn_layout.addWidget(self.toggle_inference_btn) self.toggle_visualization_btn = QPushButton("禁用可视化") self.toggle_visualization_btn.clicked.connect(self.toggle_visualization) self.settings_btn = QPushButton("设置") self.settings_btn.clicked.connect(self.open_settings) for btn in [self.toggle_visualization_btn, self.settings_btn]: btn.setStyleSheet(""" QPushButton { background-color: #0078D7; color: white; border: none; padding: 8px; border-radius: 4px; font-family: Segoe UI; font-size: 10pt; } QPushButton:hover { background-color: #106EBE; } QPushButton:pressed { background-color: #005A9E; } """) btn_layout.addWidget(btn) right_layout.addWidget(btn_group) # 添加左右区域到分割器 splitter.addWidget(left_widget) splitter.addWidget(right_widget) splitter.setSizes([600, 200]) # 设置样式 self.setStyleSheet(""" QMainWindow { background-color: #252526; } QGroupBox { font-family: Segoe UI; font-size: 10pt; color: #CCCCCC; border: 1px solid #3F3F46; border-radius: 4px; margin-top: 1ex; } QGroupBox::title { subcontrol-origin: margin; left: 10px; padding: 0 5px; background-color: transparent; } """) def open_settings(self): settings_dialog = SettingsDialog(self.detector.cfg, self) settings_dialog.exec() def update_ui(self): try: # 获取最新数据 latest_data = None while not self.frame_queue.empty(): latest_data = self.frame_queue.get_nowait() if latest_data: # 解包数据 frame, targets_count, inference_time, grab_time, target_info = latest_data # 更新性能信息 self.target_count_label.setText(f"目标数量: {targets_count}") self.inference_time_label.setText(f"推理时间: {inference_time / 1000:.3f}s") self.grab_time_label.setText(f"截图时间: {grab_time / 1000:.3f}s") # 更新目标信息 self.display_target_info(target_info) # 更新图像显示 if self.visualization_enabled and frame is not None: # 转换图像为Qt格式 height, width, channel = frame.shape bytes_per_line = 3 * width q_img = QImage(frame.data, width, height, bytes_per_line, QImage.Format.Format_BGR888) pixmap = QPixmap.fromImage(q_img) # 等比例缩放 scaled_pixmap = pixmap.scaled( self.image_label.width(), self.image_label.height(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation ) self.image_label.setPixmap(scaled_pixmap) else: # 显示黑色背景 pixmap = QPixmap(self.image_label.size()) pixmap.fill(QColor(0, 0, 0)) self.image_label.setPixmap(pixmap) # 更新鼠标状态 self.update_mouse_status() except Exception as e: print(f"更新UI时出错: {str(e)}") def display_target_info(self, target_info): """在文本框中显示目标信息""" if not target_info: self.target_info_text.setPlainText("无检测目标") return info_text = "目标类别与坐标:\n" for i, data in enumerate(target_info): try: parts = data.split(":", 1) if len(parts) == 2: class_name, coords_str = parts coords = list(map(int, coords_str.split(','))) if len(coords) == 4: display_text = f"{class_name}: [{coords[0]}, {coords[1]}, {coords[2]}, {coords[3]}]" else: display_text = f"坐标格式错误: {data}" else: display_text = f"数据格式错误: {data}" except: display_text = f"解析错误: {data}" info_text += f"{display_text}\n" self.target_info_text.setPlainText(info_text) def update_mouse_status(self): """更新鼠标右键状态显示""" with self.detector.button_lock: if self.detector.right_button_pressed: self.mouse_status.setText("瞄准中") self.mouse_status.setStyleSheet("color: #4CAF50; font-family: Consolas; font-size: 10pt;") else: self.mouse_status.setText("未瞄准") self.mouse_status.setStyleSheet("color: #FF5252; font-family: Consolas; font-size: 10pt;") def closeEvent(self, event): """安全关闭程序""" self.detection_thread.stop() self.detection_thread.wait() event.accept() class SettingsDialog(QDialog): def __init__(self, config, parent=None): super().__init__(parent) self.config = config # 保存原始配置的副本用于比较 self.original_config = config.copy() self.setWindowTitle("设置") self.setGeometry(100, 100, 600, 500) self.init_ui() def init_ui(self): layout = QVBoxLayout() self.setLayout(layout) # 标签页 tabs = QTabWidget() layout.addWidget(tabs) # 检测设置标签页 detection_tab = QWidget() detection_layout = QVBoxLayout(detection_tab) self.create_detection_settings(detection_layout) tabs.addTab(detection_tab, "检测") # 移动设置标签页 move_tab = QWidget() move_layout = QVBoxLayout(move_tab) self.create_move_settings(move_layout) tabs.addTab(move_tab, "FOV") # 目标点设置标签页 target_tab = QWidget() target_layout = QVBoxLayout(target_tab) self.create_target_settings(target_layout) tabs.addTab(target_tab, "目标点") # PID设置标签页 pid_tab = QWidget() pid_layout = QVBoxLayout(pid_tab) self.create_pid_settings(pid_layout) tabs.addTab(pid_tab, "PID") # 贝塞尔曲线设置标签页 bezier_tab = QWidget() bezier_layout = QVBoxLayout(bezier_tab) self.create_bezier_settings(bezier_layout) tabs.addTab(bezier_tab, "贝塞尔曲线") # 按钮区域 btn_layout = QHBoxLayout() layout.addLayout(btn_layout) save_btn = QPushButton("保存配置") save_btn.clicked.connect(self.save_config) cancel_btn = QPushButton("取消") cancel_btn.clicked.connect(self.reject) for btn in [save_btn, cancel_btn]: btn.setStyleSheet(""" QPushButton { background-color: #0078D7; color: white; border: none; padding: 8px 16px; border-radius: 4px; font-family: Segoe UI; font-size: 10pt; } QPushButton:hover { background-color: #106EBE; } QPushButton:pressed { background-color: #005A9E; } """) btn_layout.addWidget(btn) btn_layout.addStretch() def create_detection_settings(self, layout): # 模型选择 model_group = QGroupBox("模型设置") model_layout = QVBoxLayout(model_group) # 获取基础路径 if getattr(sys, 'frozen', False): base_path = sys._MEIPASS else: base_path = os.path.dirname(os.path.abspath(__file__)) # 获取模型文件列表 models_dir = os.path.join(base_path, 'models') model_files = [] if os.path.exists(models_dir): model_files = glob.glob(os.path.join(models_dir, '*.pt')) # 处理模型显示名称 model_display_names = [os.path.basename(f) for f in model_files] if model_files else ["未找到模型文件"] self.model_name_to_path = {os.path.basename(f): f for f in model_files} # 当前配置的模型处理 current_model_path = self.config['model_path'] current_model_name = os.path.basename(current_model_path) # 确保当前模型在列表中 if current_model_name not in model_display_names: model_display_names.append(current_model_name) self.model_name_to_path[current_model_name] = current_model_path # 模型选择下拉框 model_layout.addWidget(QLabel("选择模型:")) self.model_combo = QComboBox() self.model_combo.addItems(model_display_names) self.model_combo.setCurrentText(current_model_name) model_layout.addWidget(self.model_combo) # 设备选择 model_layout.addWidget(QLabel("运行设备:")) self.device_combo = QComboBox() self.device_combo.addItems(['auto', 'cuda', 'cpu']) self.device_combo.setCurrentText(self.config['model_device']) model_layout.addWidget(self.device_combo) layout.addWidget(model_group) # 检测参数 param_group = QGroupBox("检测参数") param_layout = QVBoxLayout(param_group) # 置信度阈值 param_layout.addWidget(QLabel("置信度阈值:")) conf_layout = QHBoxLayout() self.conf_slider = QSlider(Qt.Orientation.Horizontal) self.conf_slider.setRange(10, 100) # 0.1到1.0,步长0.01 self.conf_slider.setValue(int(float(self.config['detection_conf_thres']) * 100)) conf_layout.addWidget(self.conf_slider) self.conf_value = QLabel(f"{float(self.config['detection_conf_thres']):.2f}") self.conf_value.setFixedWidth(50) conf_layout.addWidget(self.conf_value) param_layout.addLayout(conf_layout) # 连接滑块值变化事件 self.conf_slider.valueChanged.connect(lambda value: self.conf_value.setText(f"{value / 100:.2f}")) # IOU阈值 - 改为滑动条 param_layout.addWidget(QLabel("IOU阈值:")) iou_layout = QHBoxLayout() self.iou_slider = QSlider(Qt.Orientation.Horizontal) self.iou_slider.setRange(10, 100) # 0.1到1.0,步长0.01 self.iou_slider.setValue(int(float(self.config['detection_iou_thres']) * 100)) iou_layout.addWidget(self.iou_slider) self.iou_value = QLabel(f"{float(self.config['detection_iou_thres']):.2f}") self.iou_value.setFixedWidth(50) iou_layout.addWidget(self.iou_value) param_layout.addLayout(iou_layout) # 连接滑块值变化事件 self.iou_slider.valueChanged.connect(lambda value: self.iou_value.setText(f"{value / 100:.2f}")) # 检测类别 param_layout.addWidget(QLabel("检测类别 (逗号分隔):")) self.classes_edit = QLineEdit() self.classes_edit.setText(self.config['detection_classes']) param_layout.addWidget(self.classes_edit) layout.addWidget(param_group) # 屏幕设置 screen_group = QGroupBox("屏幕设置") screen_layout = QVBoxLayout(screen_group) # 显示器编号 screen_layout.addWidget(QLabel("显示器编号:")) self.monitor_spin = QSpinBox() self.monitor_spin.setRange(0, 3) # 假设最多支持4个显示器 self.monitor_spin.setValue(int(self.config.get('screen_monitor', '0'))) screen_layout.addWidget(self.monitor_spin) # 屏幕区域大小 screen_layout.addWidget(QLabel("截屏尺寸:")) self.screen_size_spin = QSpinBox() self.screen_size_spin.setRange(100, 2000) self.screen_size_spin.setValue(int(self.config['screen_target_size'])) screen_layout.addWidget(self.screen_size_spin) layout.addWidget(screen_group) layout.addStretch() def create_move_settings(self, layout): group = QGroupBox("鼠标移动参数") group_layout = QVBoxLayout(group) # FOV设置 group_layout.addWidget(QLabel("横向FOV():")) self.fov_spin = QDoubleSpinBox() self.fov_spin.setRange(1, 179) self.fov_spin.setValue(float(self.config.get('move_fov_horizontal', '90'))) group_layout.addWidget(self.fov_spin) # 鼠标DPI group_layout.addWidget(QLabel("鼠标DPI:")) self.dpi_spin = QSpinBox() self.dpi_spin.setRange(100, 20000) self.dpi_spin.setValue(int(self.config.get('move_mouse_dpi', '400'))) group_layout.addWidget(self.dpi_spin) layout.addWidget(group) layout.addStretch() def create_target_settings(self, layout): group = QGroupBox("目标点偏移") group_layout = QVBoxLayout(group) # X轴偏移 - 添加百分比显示 group_layout.addWidget(QLabel("X轴偏移:")) x_layout = QHBoxLayout() self.x_offset_slider = QSlider(Qt.Orientation.Horizontal) self.x_offset_slider.setRange(0, 100) self.x_offset_slider.setValue(int(float(self.config.get('target_offset_x', '50')))) x_layout.addWidget(self.x_offset_slider) self.x_offset_value = QLabel(f"{int(float(self.config.get('target_offset_x', '50')))}%") self.x_offset_value.setFixedWidth(50) x_layout.addWidget(self.x_offset_value) group_layout.addLayout(x_layout) # 连接滑块值变化事件 self.x_offset_slider.valueChanged.connect(lambda value: self.x_offset_value.setText(f"{value}%")) # Y轴偏移 - 添加百分比显示 group_layout.addWidget(QLabel("Y轴偏移:")) y_layout = QHBoxLayout() self.y_offset_slider = QSlider(Qt.Orientation.Horizontal) self.y_offset_slider.setRange(0, 100) self.y_offset_slider.setValue(int(float(self.config.get('target_offset_y', '50')))) y_layout.addWidget(self.y_offset_slider) self.y_offset_value = QLabel(f"{int(float(self.config.get('target_offset_y', '50')))}%") self.y_offset_value.setFixedWidth(50) y_layout.addWidget(self.y_offset_value) group_layout.addLayout(y_layout) # 连接滑块值变化事件 self.y_offset_slider.valueChanged.connect(lambda value: self.y_offset_value.setText(f"{value}%")) # 说明 info_label = QLabel("(0% = 左上角, 50% = 中心, 100% = 右下角)") info_label.setStyleSheet("font-size: 9pt; color: #888888;") group_layout.addWidget(info_label) layout.addWidget(group) layout.addStretch() def create_pid_settings(self, layout): group = QGroupBox("PID参数") group_layout = QVBoxLayout(group) # Kp参数 group_layout.addWidget(QLabel("比例增益(Kp):")) kp_layout = QHBoxLayout() self.kp_slider = QSlider(Qt.Orientation.Horizontal) self.kp_slider.setRange(1, 1000) # 0.01到10.0,步长0.01 self.kp_slider.setValue(int(float(self.config.get('pid_kp', '1.0')) * 100)) kp_layout.addWidget(self.kp_slider) self.kp_value = QLabel(f"{float(self.config.get('pid_kp', '1.0')):.2f}") self.kp_value.setFixedWidth(50) kp_layout.addWidget(self.kp_value) group_layout.addLayout(kp_layout) # 连接滑块值变化事件 self.kp_slider.valueChanged.connect(lambda value: self.kp_value.setText(f"{value / 100:.2f}")) # Ki参数 group_layout.addWidget(QLabel("积分增益(Ki):")) ki_layout = QHBoxLayout() self.ki_slider = QSlider(Qt.Orientation.Horizontal) self.ki_slider.setRange(0, 100) # 0.0000到0.1000,步长0.001 self.ki_slider.setValue(int(float(self.config.get('pid_ki', '0.05')) * 10000)) ki_layout.addWidget(self.ki_slider) self.ki_value = QLabel(f"{float(self.config.get('pid_ki', '0.05')):.4f}") self.ki_value.setFixedWidth(50) ki_layout.addWidget(self.ki_value) group_layout.addLayout(ki_layout) # 连接滑块值变化事件 self.ki_slider.valueChanged.connect(lambda value: self.ki_value.setText(f"{value / 10000:.4f}")) # Kd参数 group_layout.addWidget(QLabel("微分增益(Kd):")) kd_layout = QHBoxLayout() self.kd_slider = QSlider(Qt.Orientation.Horizontal) self.kd_slider.setRange(0, 5000) # 0.000到5.000,步长0.001 self.kd_slider.setValue(int(float(self.config.get('pid_kd', '0.2')) * 1000)) kd_layout.addWidget(self.kd_slider) self.kd_value = QLabel(f"{float(self.config.get('pid_kd', '0.2')):.3f}") self.kd_value.setFixedWidth(50) kd_layout.addWidget(self.kd_value) group_layout.addLayout(kd_layout) # 连接滑块值变化事件 self.kd_slider.valueChanged.connect(lambda value: self.kd_value.setText(f"{value / 1000:.3f}")) # 说明 info_text = "建议调整顺序: Kp → Kd → Ki\n\n" \ "先调整Kp至响应迅速但不过冲\n" \ "再增加Kd抑制震荡\n" \ "最后微调Ki消除剩余误差" info_label = QLabel(info_text) info_label.setStyleSheet("font-size: 9pt; color: #888888;") group_layout.addWidget(info_label) layout.addWidget(group) layout.addStretch() # 创建贝塞尔曲线设置 def create_bezier_settings(self, layout): group = QGroupBox("贝塞尔曲线参数") group_layout = QVBoxLayout(group) # 步数设置 group_layout.addWidget(QLabel("步数 (1-500):")) steps_layout = QHBoxLayout() self.steps_slider = QSlider(Qt.Orientation.Horizontal) self.steps_slider.setRange(1, 500) self.steps_slider.setValue(int(self.config.get('bezier_steps', 100))) steps_layout.addWidget(self.steps_slider) self.steps_value = QLabel(str(self.config.get('bezier_steps', 100))) self.steps_value.setFixedWidth(50) steps_layout.addWidget(self.steps_value) group_layout.addLayout(steps_layout) # 连接滑块值变化事件 self.steps_slider.valueChanged.connect(lambda value: self.steps_value.setText(str(value))) # 总移动时间设置 () group_layout.addWidget(QLabel("总移动时间 (秒, 0.01-1.0):")) duration_layout = QHBoxLayout() self.duration_slider = QSlider(Qt.Orientation.Horizontal) self.duration_slider.setRange(1, 100) # 0.01到1.0,步长0.01 self.duration_slider.setValue(int(float(self.config.get('bezier_duration', 0.1)) * 100)) duration_layout.addWidget(self.duration_slider) self.duration_value = QLabel(f"{float(self.config.get('bezier_duration', 0.1)):.2f}") self.duration_value.setFixedWidth(50) duration_layout.addWidget(self.duration_value) group_layout.addLayout(duration_layout) # 连接滑块值变化事件 self.duration_slider.valueChanged.connect(lambda value: self.duration_value.setText(f"{value / 100:.2f}")) # 控制点偏移幅度 group_layout.addWidget(QLabel("控制点偏移幅度 (0-1):")) curve_layout = QHBoxLayout() self.curve_slider = QSlider(Qt.Orientation.Horizontal) self.curve_slider.setRange(0, 100) # 0.00到1.00,步长0.01 self.curve_slider.setValue(int(float(self.config.get('bezier_curve', 0.3)) * 100)) curve_layout.addWidget(self.curve_slider) self.curve_value = QLabel(f"{float(self.config.get('bezier_curve', 0.3)):.2f}") self.curve_value.setFixedWidth(50) curve_layout.addWidget(self.curve_value) group_layout.addLayout(curve_layout) # 连接滑块值变化事件 self.curve_slider.valueChanged.connect(lambda value: self.curve_value.setText(f"{value / 100:.2f}")) # 说明 info_text = "贝塞尔曲线参数说明:\n\n" \ "• 步数: 鼠标移动的细分步数,值越大移动越平滑\n" \ "• 总移动时间: 鼠标移动的总时间,值越小移动越快\n" \ "• 控制点偏移幅度: 控制贝塞尔曲线的弯曲程度,0为直线,1为最大弯曲" info_label = QLabel(info_text) info_label.setStyleSheet("font-size: 9pt; color: #888888;") group_layout.addWidget(info_label) layout.addWidget(group) layout.addStretch() def save_config(self): try: # 保存配置到字典 model_name = self.model_combo.currentText() model_path = self.model_name_to_path.get(model_name, model_name) self.config['model_path'] = model_path self.config['model_device'] = self.device_combo.currentText() self.config['screen_monitor'] = str(self.monitor_spin.value()) self.config['screen_target_size'] = str(self.screen_size_spin.value()) # 检测参数 self.config['detection_conf_thres'] = str(self.conf_slider.value() / 100) self.config['detection_iou_thres'] = str(self.iou_slider.value() / 100) self.config['detection_classes'] = self.classes_edit.text() # 移动设置 self.config['move_fov_horizontal'] = str(self.fov_spin.value()) self.config['move_mouse_dpi'] = str(self.dpi_spin.value()) # 目标点偏移设置 self.config['target_offset_x'] = str(self.x_offset_slider.value()) self.config['target_offset_y'] = str(self.y_offset_slider.value()) # PID设置 self.config['pid_kp'] = str(self.kp_slider.value() / 100) self.config['pid_ki'] = str(self.ki_slider.value() / 10000) self.config['pid_kd'] = str(self.kd_slider.value() / 1000) # 贝塞尔曲线设置 self.config['bezier_steps'] = str(self.steps_slider.value()) self.config['bezier_duration'] = str(self.duration_slider.value() / 100) self.config['bezier_curve'] = str(self.curve_slider.value() / 100) # 保存为TXT格式 with open('detection_config.txt', 'w', encoding='utf-8') as f: for key, value in self.config.items(): f.write(f"{key} = {value}\n") # 检查需要重启的参数是否被修改 restart_required = False restart_params = [] # 比较模型路径是否变化 if self.config['model_path'] != self.original_config.get('model_path', ''): restart_required = True restart_params.append("模型路径") # 比较设备类型是否变化 if self.config['model_device'] != self.original_config.get('model_device', ''): restart_required = True restart_params.append("设备类型") # 比较屏幕区域大小是否变化 if self.config['screen_target_size'] != self.original_config.get('screen_target_size', ''): restart_required = True restart_params.append("屏幕区域大小") # 比较检测类别是否变化 if self.config['detection_classes'] != self.original_config.get('detection_classes', ''): restart_required = True restart_params.append("检测类别") # 动态更新检测器配置 if self.parent() and hasattr(self.parent(), 'detector'): success = self.parent().detector.update_config('detection_config.txt') if success: if restart_required: # 需要重启的参数已修改 param_list = "、".join(restart_params) QMessageBox.information( self, "配置已保存", f"配置已保存!以下参数需要重启才能生效:\n{param_list}\n\n" "其他参数已实时更新。" ) else: # 所有参数都已实时更新 QMessageBox.information(self, "成功", "配置已实时更新生效!") else: QMessageBox.warning(self, "部分更新", "配置更新失败,请查看日志") else: QMessageBox.information(self, "成功", "配置已保存!部分参数需重启生效") self.accept() except Exception as e: QMessageBox.critical(self, "错误", f"保存配置失败: {str(e)}") if __name__ == "__main__": detector = ScreenDetector('detection_config.txt') print(f"\nDXcam检测器初始化完成 | 设备: {detector.device.upper()}") app = QApplication(sys.argv) # 设置全局样式 app.setStyle("Fusion") app.setStyleSheet(""" QWidget { background-color: #252526; color: #D4D4D4; selection-background-color: #0078D7; selection-color: white; } QPushButton { background-color: #0078D7; color: white; border: none; padding: 5px 10px; border-radius: 4px; } QPushButton:hover { background-color: #106EBE; } QPushButton:pressed { background-color: #005A9E; } QComboBox, QLineEdit, QSpinBox, QDoubleSpinBox, QSlider { background-color: #3C3C40; color: #D4D4D4; border: 1px solid #3F3F46; border-radius: 4px; padding: 3px; } QComboBox:editable { background-color: #3C3C40; } QComboBox QAbstractItemView { background-color: #2D2D30; color: #D4D4D4; selection-background-color: #0078D7; selection-color: white; } QLabel { color: #D4D4D4; } QTabWidget::pane { border: 1px solid #3F3F46; background: #252526; } QTabBar::tab { background: #1E1E1E; color: #A0A0A0; padding: 8px 12px; border-top-left-radius: 4px; border-top-right-radius: 4px; } QTabBar::tab:selected { background: #252526; color: #FFFFFF; border-bottom: 2px solid #0078D7; } QTabBar::tab:hover { background: #2D2D30; } QGroupBox { background-color: #252526; border: 1px solid #3F3F46; border-radius: 4px; margin-top: 1ex; } QGroupBox::title { subcontrol-origin: margin; left: 10px; padding: 0 5px; background-color: transparent; color: #CCCCCC; } """) window = MainWindow(detector) window.show() sys.exit(app.exec()) 考虑所有问题,重构我的代码,最后给我完整的代码
07-16
根据原作 https://pan.quark.cn/s/459657bcfd45 的源码改编 Classic-ML-Methods-Algo 引言 建立这个项目,是为了梳理和总结传统机器学习(Machine Learning)方法(methods)或者算法(algo),和各位同仁相互学习交流. 现在的深度学习本质上来自于传统的神经网络模型,很大程度上是传统机器学习的延续,同时也在不少时候需要结合传统方法来实现. 任何机器学习方法基本的流程结构都是通用的;使用的评价方法也基本通用;使用的些数学知识也是通用的. 本文在梳理传统机器学习方法算法的同时也会顺便补充这些流程,数学上的知识以供参考. 机器学习 机器学习是人工智能(Artificial Intelligence)个分支,也是实现人工智能最重要的手段.区别于传统的基于规则(rule-based)的算法,机器学习可以从数据中获取知识,从而实现规定的任务[Ian Goodfellow and Yoshua Bengio and Aaron Courville的Deep Learning].这些知识可以分为四种: 总结(summarization) 预测(prediction) 估计(estimation) 假想验证(hypothesis testing) 机器学习主要关心的是预测[Varian在Big Data : New Tricks for Econometrics],预测的可以是连续性的输出变量,分类,聚类或者物品之间的有趣关联. 机器学习分类 根据数据配置(setting,是否有标签,可以是连续的也可以是离散的)和任务目标,我们可以将机器学习方法分为四种: 无监督(unsupervised) 训练数据没有给定...
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值