Deep Learning Deployment (13): CUDA Runtime API thread_layout (Thread Layout)

This article explains how to configure CUDA-CPP support in VSCode and focuses on the key concepts of CUDA programming, such as how to set gridDim and blockDim and what the warp size (warpSize) and the per-block thread limit (maxThreadsPerBlock) mean. Example code shows how to launch a kernel and how to query GPU device properties in order to choose a sensible layout.


1. Key Points

  1. Configuring "*.cu": "cuda-cpp" in .vscode/settings.json enables CUDA syntax parsing for .cu files

  2. The layout sets the number of threads the kernel runs with; you need to understand the maximum grid size, the maximum block dimensions, the per-block thread limit, and the warp size

    • maxGridSize is the upper bound for gridDim
    • maxThreadsDim is the upper bound for each component of blockDim
    • warpSize is the number of threads in a warp
    • maxThreadsPerBlock is the upper bound for the product of the blockDim components
  3. How the four main layout variables relate

    • gridDim is a layout dimension; its corresponding index is blockIdx
      • blockIdx ranges from 0 to gridDim - 1
    • blockDim is a layout dimension; its corresponding index is threadIdx
      • threadIdx ranges from 0 to blockDim - 1
      • the product of the blockDim components must be less than or equal to maxThreadsPerBlock
    • gridDim and blockDim are therefore called dimensions: once the kernel is launched they are fixed
    • blockIdx and threadIdx are therefore called indices: after launch they enumerate every value within those dimensions, and each thread sees a different value
    • The warp concept is not covered in depth here; look it up if you are interested
  4. When a kernel is launched, the parameters inside <<<>>> are: <<<gridDim, blockDim, shared_memory_size, cudaStream_t>>> (see the sketch below)

    • shared_memory_size sets the size of dynamically allocated shared memory; it is covered later in the discussion of shared memory
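
As a quick illustration of that launch syntax, here is a minimal sketch (the kernel name and sizes are made up for the example and are not part of the original article):

#include <cuda_runtime.h>

__global__ void my_kernel(float* data){
    extern __shared__ float cache[];   // backed by the dynamic shared memory size passed at launch
    // ...
}

void launch_example(float* d_data, cudaStream_t stream){
    dim3 grid(16, 1, 1);      // gridDim: 16 blocks along x
    dim3 block(256, 1, 1);    // blockDim: 256 threads per block (must not exceed maxThreadsPerBlock)
    size_t shared_bytes = 256 * sizeof(float);  // dynamic shared memory per block
    // <<<gridDim, blockDim, shared_memory_size, cudaStream_t>>>
    my_kernel<<<grid, block, shared_bytes, stream>>>(d_data);
}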

2. Knowledge Points Illustrated

2.1 How should we understand that the layout sets the number of threads a kernel runs with, and why do the maximum dimensions, the per-block thread limit, and the warp size matter?
The layout is a grid: the grid contains a number of blocks, and each block contains threads.

The warp size is the number of threads in one warp on the GPU. A warp is also the smallest scheduling unit: the GPU schedules all threads of a warp together so they can execute in parallel.

  1. If my GPU has warp size = 32, does that mean the smallest scheduling unit is 32 threads?

Yes. All threads on an SM are grouped into warps of warp-size threads; the threads within a warp execute together, and the smallest scheduling unit is one warp, i.e. 32 threads. If the threads of a warp take different paths at a branch or conditional, the warp has to execute the paths one after another with the non-participating threads masked off (warp divergence), which can hurt performance. Minimizing divergent branches within a warp therefore improves GPU efficiency.
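
To make divergence concrete, here is a small illustrative kernel pair (not from the original article): the first variant makes lanes of the same warp take different paths, while the second branches on a per-warp value so all 32 lanes of a warp agree.

#include <cuda_runtime.h>

// Divergent: even and odd lanes of the same warp take different paths,
// so the warp executes both branches one after another.
__global__ void divergent_kernel(float* out){
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.x % 2 == 0) out[i] = i * 2.0f;
    else                      out[i] = i * 3.0f;
}

// Uniform per warp: all 32 lanes of a warp share the same warp_id,
// so every thread in the warp takes the same branch.
__global__ void uniform_kernel(float* out){
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int warp_id = threadIdx.x / 32;
    if (warp_id % 2 == 0) out[i] = i * 2.0f;
    else                  out[i] = i * 3.0f;
}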

  1. Worked example:
    (Figure: a 3×2 grid of blocks, each block containing 4×2 threads, with one cell highlighted in yellow)
    The layout in this example:
  • gridDim.x = 3
  • gridDim.y = 2
  • gridDim.z = 1 (the default value)
  • blockDim.x = 4
  • blockDim.y = 2
  • blockDim.z = 1 (the default value)

Plugging these numbers into the index calculation gives 13 for the yellow (highlighted) cell.
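
The flattening behind that number can be written as a device helper. This is a sketch of the usual left-to-right flattening; the concrete blockIdx/threadIdx values in the comment are an assumption about which cell the figure highlights, since the figure is not reproduced here.

#include <cuda_runtime.h>

// Flatten (blockIdx, threadIdx) into a single global thread index:
// linearize the block within the grid, then the thread within the block.
__device__ int global_thread_index(){
    int threads_per_block = blockDim.x * blockDim.y * blockDim.z;
    int block_linear  = (blockIdx.z  * gridDim.y  + blockIdx.y)  * gridDim.x  + blockIdx.x;
    int thread_linear = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    return block_linear * threads_per_block + thread_linear;
}

// With gridDim = (3, 2, 1) and blockDim = (4, 2, 1):
//   blockIdx  = (1, 0, 0) -> block_linear  = (0 * 2 + 0) * 3 + 1 = 1
//   threadIdx = (1, 1, 0) -> thread_linear = (0 * 2 + 1) * 4 + 1 = 5
//   global index = 1 * 8 + 5 = 13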

3. main.cpp

#include <cuda_runtime.h>
#include <stdio.h>

#define checkRuntime(op)  __check_cuda_runtime((op), #op, __FILE__, __LINE__)

bool __check_cuda_runtime(cudaError_t code, const char* op, const char* file, int line){
    if(code != cudaSuccess){    
        const char* err_name = cudaGetErrorName(code);    
        const char* err_message = cudaGetErrorString(code);  
        printf("runtime error %s:%d  %s failed. \n  code = %s, message = %s\n", file, line, op, err_name, err_message);   
        return false;
    }
    return true;
}

void launch(int* grids, int* blocks);

int main(){

    cudaDeviceProp prop;
    checkRuntime(cudaGetDeviceProperties(&prop, 0));

    // Querying maxGridSize and maxThreadsDim tells us the maximum values allowed for gridDims and blockDims
    // warpSize is the number of threads in a warp
    // maxThreadsPerBlock is the maximum number of threads a single block can hold, i.e. blockDims[0] * blockDims[1] * blockDims[2] <= maxThreadsPerBlock
    printf("prop.maxGridSize = %d, %d, %d\n", prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    printf("prop.maxThreadsDim = %d, %d, %d\n", prop.maxThreadsDim[0], prop.maxThreadsDim[1], prop.maxThreadsDim[2]);
    printf("prop.warpSize = %d\n", prop.warpSize);
    printf("prop.maxThreadsPerBlock = %d\n", prop.maxThreadsPerBlock);

    int grids[] = {1, 2, 3};     // gridDim.x  gridDim.y  gridDim.z 
    int blocks[] = {1024, 1, 1}; // blockDim.x blockDim.y blockDim.z 
    launch(grids, blocks);       // grids says how many big cells (blocks) the grid has, blocks says how many small cells (threads) each big cell holds
    checkRuntime(cudaPeekAtLastError());   // fetch the last error code without clearing it
    checkRuntime(cudaDeviceSynchronize()); // synchronize; all the code above this line may have run asynchronously
    printf("done\n");
    return 0;
}
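
Building on the properties queried above, one reasonable pattern (a sketch, not part of the original code) is to validate a requested layout against the device limits before launching:

#include <cuda_runtime.h>
#include <stdio.h>

// Returns true if the requested grid/block sizes fit within the device limits.
bool layout_fits(const cudaDeviceProp& prop, const dim3& grid, const dim3& block){
    size_t threads_per_block = (size_t)block.x * block.y * block.z;
    bool grid_ok  = grid.x  <= (unsigned)prop.maxGridSize[0] &&
                    grid.y  <= (unsigned)prop.maxGridSize[1] &&
                    grid.z  <= (unsigned)prop.maxGridSize[2];
    bool block_ok = block.x <= (unsigned)prop.maxThreadsDim[0] &&
                    block.y <= (unsigned)prop.maxThreadsDim[1] &&
                    block.z <= (unsigned)prop.maxThreadsDim[2] &&
                    threads_per_block <= (size_t)prop.maxThreadsPerBlock;
    if(!grid_ok || !block_ok)
        printf("requested layout exceeds device limits\n");
    return grid_ok && block_ok;
}

If this check fails, launching anyway would typically produce an invalid-configuration error, which the cudaPeekAtLastError check above would then report.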

4. The .cu file

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void demo_kernel(){

    if(blockIdx.x == 0 && threadIdx.x == 0)
        printf("Run kernel. blockIdx = %d,%d,%d  threadIdx = %d,%d,%d\n",
            blockIdx.x, blockIdx.y, blockIdx.z,
            threadIdx.x, threadIdx.y, threadIdx.z
        );
}

void launch(int* grids, int* blocks){

    dim3 grid_dims(grids[0], grids[1], grids[2]);
    dim3 block_dims(blocks[0], blocks[1], blocks[2]);
    demo_kernel<<<grid_dims, block_dims, 0, nullptr>>>();
}

5. Code Walkthrough

void search_demo()
{
    // Define a struct to hold the device properties
    // cudaGetDeviceProperties fills it for device 0; if you ask for device 1 and it does not exist, checkRuntime reports the error
    cudaDeviceProp prop;
    checkRuntime(cudaGetDeviceProperties(&prop, 0));

    // Query maxGridSize: how many blocks each grid dimension can hold
    printf("prop.maxGridSize = %d, %d, %d\n", prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);

    // Query the maximum number of threads along each block dimension
    printf("prop.maxThreadsDim = %d, %d, %d\n", prop.maxThreadsDim[0], prop.maxThreadsDim[1], prop.maxThreadsDim[2]);

    // Query the warp size
    printf("prop.warpSize = %d\n", prop.warpSize);

    printf("prop.maxThreadsPerBlock = %d\n", prop.maxThreadsPerBlock);
}

Output:

prop.maxGridSize = 2147483647, 65535, 65535
prop.maxThreadsDim = 1024, 1024, 64
prop.warpSize = 32
prop.maxThreadsPerBlock = 1024
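
The comment about asking for a non-existent device can be handled defensively. One possible pattern (a sketch, not in the original code) is to check cudaGetDeviceCount before querying a device's properties:

#include <cuda_runtime.h>
#include <stdio.h>

// Fill prop for device_id only if that device actually exists.
bool get_props_checked(int device_id, cudaDeviceProp* prop){
    int count = 0;
    if(cudaGetDeviceCount(&count) != cudaSuccess || device_id >= count){
        printf("device %d not available (found %d device(s))\n", device_id, count);
        return false;
    }
    return cudaGetDeviceProperties(prop, device_id) == cudaSuccess;
}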


6. Reading the .cu file

The kernel is launched from main.cpp:

main.cpp

int main(){
    search_demo(); // query the device properties so we know the layout limits

    // layout demo: define the layout
    int grids[] = {1, 2, 3}; // gridDim.x, gridDim.y, gridDim.z
    int blocks[] = {1024, 1, 1}; // blockDim.x, blockDim.y, blockDim.z
    launch(grids, blocks);   // grids says how many big cells (blocks) the grid has, blocks says how many small cells (threads) each big cell holds
    checkRuntime(cudaPeekAtLastError());   // fetch the last error code without clearing it
    checkRuntime(cudaDeviceSynchronize()); // synchronize; all the code above this line may have run asynchronously
    return 0;
}

.cu

#include <cuda_runtime.h>
#include <stdio.h>

__global__ void demo_kernel()
{
    // In this example, only the threads with blockIdx.x == 0 and threadIdx.x == 0 print their information
    // blockIdx indexes a block within the grid, threadIdx indexes a thread within the block
    if (blockIdx.x == 0 && threadIdx.x == 0)
    {
        printf("Run kernel. blockIdx = %d,%d,%d  threadIdx = %d,%d,%d\n",
               blockIdx.x, blockIdx.y, blockIdx.z,
               threadIdx.x, threadIdx.y, threadIdx.z);
    }
}

void launch(int* grids, int* blocks)
{
    dim3 grid_dims(grids[0], grids[1], grids[2]);
    dim3 block_dims(blocks[0], blocks[1], blocks[2]);
    demo_kernel<<<grid_dims, block_dims, 0, nullptr>>>();
}

Output:

prop.maxGridSize = 2147483647, 65535, 65535
prop.maxThreadsDim = 1024, 1024, 64
prop.warpSize = 32
prop.maxThreadsPerBlock = 1024
Run kernel. blockIdx = 0,0,1  threadIdx = 0,0,0
Run kernel. blockIdx = 0,1,2  threadIdx = 0,0,0
Run kernel. blockIdx = 0,1,1  threadIdx = 0,0,0
Run kernel. blockIdx = 0,0,2  threadIdx = 0,0,0
Run kernel. blockIdx = 0,0,0  threadIdx = 0,0,0
Run kernel. blockIdx = 0,1,0  threadIdx = 0,0,0

In the launch function here, blocks[0] is 1024 and the other two dimensions are 1, so each block holds 1024 threads. The grids array describes the size of the whole grid: grids[0] means 1 block along x, grids[1] means 2 blocks along y, and grids[2] means 3 blocks along z. That gives 6 blocks in total, each with 1024 threads, so 6144 threads overall.
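
That total can be computed directly from the launch dimensions; a trivial sketch (not from the original code):

#include <cstdio>
#include <cuda_runtime.h>

int main(){
    dim3 grid(1, 2, 3), block(1024, 1, 1);
    size_t total_blocks  = (size_t)grid.x * grid.y * grid.z;            // 6
    size_t total_threads = total_blocks * block.x * block.y * block.z;  // 6144
    printf("blocks = %zu, threads = %zu\n", total_blocks, total_threads);
    return 0;
}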

demo_kernel()

Only the threads with blockIdx.x == 0 and threadIdx.x == 0 print. Since gridDim.x is 1 here, every block satisfies blockIdx.x == 0, and exactly one thread per block satisfies threadIdx.x == 0, which is why six lines appear above (one per block, in no particular order).

Remember: blockIdx indexes a block within the grid, and threadIdx indexes a thread within the block.

self.log_text.append(f"已选择{len(files)}个音频文件") def browse_keyword(self): """浏览关键词文件""" options = QFileDialog.Options() file, _ = QFileDialog.getOpenFileName( self, "选择关键词文件", "", "Excel文件 (*.xlsx *.xls);;所有文件 (*)", options=options ) if file: self.keyword_file = file self.keyword_path_edit.setText(file) self.log_text.append(f"已选择关键词文件: {file}") def browse_output(self): """浏览输出目录""" options = QFileDialog.Options() directory = QFileDialog.getExistingDirectory( self, "选择输出目录", self.output_dir, options=options ) if directory: self.output_dir = directory self.output_edit.setText(directory) self.log_text.append(f"输出目录设置为: {directory}") def start_analysis(self): """开始分析""" if not self.audio_files: self.show_warning("请先选择音频文件") return if not self.keyword_file: self.show_warning("请先选择关键词文件") return if not os.path.exists(self.keyword_file): self.show_warning("关键词文件不存在,请重新选择") return # 检查模型路径 model_paths = [ self.whisper_edit.text(), self.pyannote_edit.text(), self.emotion_edit.text() ] for path in model_paths: if not os.path.exists(path): self.show_warning(f"模型路径不存在: {path}") return # 更新模型路径 self.whisper_model_path = self.whisper_edit.text() self.pyannote_model_path = self.pyannote_edit.text() self.emotion_model_path = self.emotion_edit.text() self.output_dir = self.output_edit.text() # 创建输出目录 os.makedirs(self.output_dir, exist_ok=True) self.log_text.append("开始分析...") self.start_btn.setEnabled(False) self.stop_btn.setEnabled(True) self.status_label.setText("状态: 分析中...") self.progress_bar.setFormat("分析中... 0%") self.progress_bar.setValue(0) # 创建并启动分析线程 self.analysis_thread = AnalysisThread( self.audio_files, self.keyword_file, self.whisper_model_path, self.pyannote_model_path, self.emotion_model_path ) self.analysis_thread.progress.connect(self.update_progress) self.analysis_thread.message.connect(self.log_text.append) self.analysis_thread.analysis_complete.connect(self.on_analysis_complete) self.analysis_thread.error.connect(self.on_analysis_error) self.analysis_thread.finished.connect(self.on_analysis_finished) self.analysis_thread.start() def update_progress(self, value): """更新进度条""" self.progress_bar.setValue(value) self.progress_bar.setFormat(f"分析中... 
{value}%") def stop_analysis(self): """停止分析""" if self.analysis_thread and self.analysis_thread.isRunning(): self.analysis_thread.stop() self.log_text.append("正在停止分析...") self.stop_btn.setEnabled(False) def clear_all(self): """清空所有内容""" self.audio_files = [] self.keyword_file = "" self.audio_path_edit.clear() self.keyword_path_edit.clear() self.log_text.clear() self.progress_bar.setValue(0) self.progress_bar.setFormat("就绪") self.status_label.setText("状态: 就绪") self.file_count_label.setText("已选择0个音频文件") self.log_text.append("已清空所有内容") def show_warning(self, message): """显示警告消息""" QMessageBox.warning(self, "警告", message) self.log_text.append(f"警告: {message}") def on_analysis_complete(self, result): """分析完成处理""" try: self.log_text.append("正在生成报告...") if not result.get("results"): self.log_text.append("警告: 没有生成任何分析结果") return # 生成Excel报告 excel_path = os.path.join(self.output_dir, "质检分析报告.xlsx") self.generate_excel_report(result, excel_path) # 生成Word报告 word_path = os.path.join(self.output_dir, "质检分析报告.docx") self.generate_word_report(result, word_path) self.log_text.append(f"分析报告已保存至: {excel_path}") self.log_text.append(f"可视化报告已保存至: {word_path}") self.log_text.append("分析完成!") self.status_label.setText(f"状态: 分析完成!报告保存至: {self.output_dir}") self.progress_bar.setFormat("分析完成!") # 显示完成消息 QMessageBox.information( self, "分析完成", f"分析完成!报告已保存至:\n{excel_path}\n{word_path}" ) except Exception as e: import traceback error_msg = f"生成报告时出错: {str(e)}\n{traceback.format_exc()}" self.log_text.append(error_msg) logger.error(error_msg) def on_analysis_error(self, message): """分析错误处理""" self.log_text.append(f"错误: {message}") self.status_label.setText("状态: 发生错误") self.progress_bar.setFormat("发生错误") QMessageBox.critical(self, "分析错误", message) def on_analysis_finished(self): """分析线程结束处理""" self.start_btn.setEnabled(True) self.stop_btn.setEnabled(False) def generate_excel_report(self, result, output_path): """生成Excel报告(增强健壮性)""" try: # 从结果中提取数据 data = [] for res in result['results']: data.append({ "文件名": res['file_name'], "音频时长()": res['duration'], "开场白检查": "通过" if res['opening_check'] else "未通过", "结束语检查": "通过" if res['closing_check'] else "未通过", "服务禁语检查": "通过" if not res['forbidden_check'] else "未通过", "客服情感": res['agent_emotion']['label'], "客服情感得分": res['agent_emotion']['score'], "客户情感": res['customer_emotion']['label'], "客户情感得分": res['customer_emotion']['score'], "语速(字/分)": res['speech_rate'], "平均音量(dB)": res['volume_mean'], "音量标准差": res['volume_std'], "问题解决率": "是" if res['resolution_rate'] else "否" }) # 创建DataFrame并保存 df = pd.DataFrame(data) # 尝试使用openpyxl引擎(更稳定) try: df.to_excel(output_path, index=False, engine='openpyxl') except ImportError: df.to_excel(output_path, index=False) # 添加汇总统计 try: with pd.ExcelWriter(output_path, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer: summary_data = { "统计项": ["总文件数", "开场白通过率", "结束语通过率", "服务禁语通过率", "问题解决率"], "数值": [ len(result['results']), df['开场白检查'].value_counts().get('通过', 0) / len(df), df['结束语检查'].value_counts().get('通过', 0) / len(df), df['服务禁语检查'].value_counts().get('通过', 0) / len(df), df['问题解决率'].value_counts().get('是', 0) / len(df) ] } summary_df = pd.DataFrame(summary_data) summary_df.to_excel(writer, sheet_name='汇总统计', index=False) except Exception as e: self.log_text.append(f"添加汇总统计时出错: {str(e)}") except Exception as e: raise Exception(f"生成Excel报告失败: {str(e)}") def generate_word_report(self, result, output_path): """生成Word报告(增强健壮性)""" try: doc = Document() # 添加标题 doc.add_heading('外呼电话录音质检分析报告', 0) # 添加基本信息 doc.add_heading('分析概况', level=1) 
doc.add_paragraph(f"分析时间: {time.strftime('%Y-%m-%d %H:%M:%S')}") doc.add_paragraph(f"分析文件数量: {len(result['results'])}") doc.add_paragraph(f"关键词文件: {os.path.basename(self.keyword_file)}") # 添加汇总统计 doc.add_heading('汇总统计', level=1) # 创建汇总表格 table = doc.add_table(rows=5, cols=2) table.style = 'Table Grid' # 表头 hdr_cells = table.rows[0].cells hdr_cells[0].text = '统计项' hdr_cells[1].text = '数值' # 计算统计数据 df = pd.DataFrame(result['results']) pass_rates = { "开场白通过率": df['opening_check'].mean() if not df.empty else 0, "结束语通过率": df['closing_check'].mean() if not df.empty else 0, "服务禁语通过率": (1 - df['forbidden_check']).mean() if not df.empty else 0, "问题解决率": df['resolution_rate'].mean() if not df.empty else 0 } # 填充表格 rows = [ ("总文件数", len(result['results'])), ("开场白通过率", f"{pass_rates['开场白通过率']:.2%}"), ("结束语通过率", f"{pass_rates['结束语通过率']:.2%}"), ("服务禁语通过率", f"{pass_rates['服务禁语通过率']:.2%}"), ("问题解决率", f"{pass_rates['问题解决率']:.2%}") ] for i, row_data in enumerate(rows): if i < len(table.rows): row_cells = table.rows[i].cells row_cells[0].text = row_data[0] row_cells[1].text = str(row_data[1]) # 添加情感分析图表 if result['results']: doc.add_heading('情感分析', level=1) # 客服情感分布 agent_emotions = [res['agent_emotion']['label'] for res in result['results']] agent_emotion_counts = pd.Series(agent_emotions).value_counts() if not agent_emotion_counts.empty: fig, ax = plt.subplots(figsize=(6, 4)) agent_emotion_counts.plot.pie(autopct='%1.1f%%', ax=ax) ax.set_title('客服情感分布') ax.set_ylabel('') # 移除默认的ylabel plt.tight_layout() # 保存图表到临时文件 chart_path = os.path.join(self.output_dir, "agent_emotion_chart.png") plt.savefig(chart_path, dpi=100, bbox_inches='tight') plt.close() doc.add_picture(chart_path, width=Inches(4)) doc.add_paragraph('图1: 客服情感分布') # 客户情感分布 customer_emotions = [res['customer_emotion']['label'] for res in result['results']] customer_emotion_counts = pd.Series(customer_emotions).value_counts() if not customer_emotion_counts.empty: fig, ax = plt.subplots(figsize=(6, 4)) customer_emotion_counts.plot.pie(autopct='%1.1f%%', ax=ax) ax.set_title('客户情感分布') ax.set_ylabel('') # 移除默认的ylabel plt.tight_layout() chart_path = os.path.join(self.output_dir, "customer_emotion_chart.png") plt.savefig(chart_path, dpi=100, bbox_inches='tight') plt.close() doc.add_picture(chart_path, width=Inches(4)) doc.add_paragraph('图2: 客户情感分布') # 添加详细分析结果 doc.add_heading('详细分析结果', level=1) # 创建详细表格 table = doc.add_table(rows=1, cols=6) table.style = 'Table Grid' # 表头 hdr_cells = table.rows[0].cells headers = ['文件名', '开场白', '结束语', '禁语', '客服情感', '问题解决'] for i, header in enumerate(headers): hdr_cells[i].text = header # 填充数据 for res in result['results']: row_cells = table.add_row().cells row_cells[0].text = res['file_name'] row_cells[1].text = "✓" if res['opening_check'] else "✗" row_cells[2].text = "✓" if res['closing_check'] else "✗" row_cells[3].text = "✗" if res['forbidden_check'] else "✓" row_cells[4].text = res['agent_emotion']['label'] row_cells[5].text = "✓" if res['resolution_rate'] else "✗" # 保存文档 doc.save(output_path) except Exception as e: raise Exception(f"生成Word报告失败: {str(e)}") if __name__ == "__main__": # 检查是否安装了torch try: import torch except ImportError: print("警告: PyTorch 未安装,情感分析可能无法使用GPU加速") app = QApplication(sys.argv) # 设置应用样式 app.setStyle("Fusion") window = MainWindow() window.show() sys.exit(app.exec_())
### CUDA `cuda::thread_scope::thread_scope_thread`: purpose and usage

#### 1. Role and definition
`cuda::thread_scope_thread` is one of the scope identifiers in the CUDA memory model (libcu++). It specifies the visibility range of an atomic operation or memory fence: the guarantees apply only within the current thread. Other threads, whether in the same block or anywhere else on the device, are not synchronized by the operation.

#### 2. Core behavior
- **Thread-local atomics**: atomic accesses performed by the current thread are only required to be ordered and visible within that thread.
- **Lower synchronization cost**: narrowing the scope means the hardware does not have to enforce ordering across threads or thread blocks, which reduces overhead.
- **Memory-order control**: combined with a fence such as `cuda::atomic_thread_fence`, it constrains instruction ordering inside the thread.

#### 3. Typical use cases
- **Thread-local state**: a counter or state machine that a thread maintains only for itself (for example, counting events inside that thread).
- **Private temporaries**: atomic operations on thread-private data that is never shared with other threads.
- **Lightweight synchronization**: when there is no contention with other threads, the smallest scope avoids unnecessary cost.

#### 4. Code example
```cpp
#include <cuda/atomic>

__global__ void kernel() {
    // Declare an atomic whose synchronization scope is the current thread only
    cuda::atomic<int, cuda::thread_scope_thread> local_counter(0);

    // The update is only guaranteed to be ordered/visible within this thread
    local_counter.fetch_add(1, cuda::std::memory_order_relaxed);
}
```
In this example, the modification of `local_counter` is only meaningful to the current thread; other threads cannot rely on observing it.

#### 5. Comparison with the other scopes

| Scope | Range | Typical use |
|-----------------------|--------------------------|--------------------------------------------|
| `thread_scope_thread` | a single thread | thread-local state |
| `thread_scope_block` | all threads of one block | synchronizing data shared inside a block |
| `thread_scope_device` | all threads on one GPU | synchronizing global-memory operations |
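To make the scope comparison concrete, a minimal sketch follows: it counts threads once with a block-scoped atomic on a `__shared__` counter and once with a device-scoped atomic on a global counter. It assumes a CUDA toolkit whose libcu++ provides `cuda::atomic_ref`; the kernel name `scoped_counters` and the buffer names are made up for illustration.

```cpp
#include <cuda/atomic>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scoped_counters(int* device_total, int* per_block) {
    __shared__ int block_total;                 // plain int, accessed through atomic_ref
    if (threadIdx.x == 0) block_total = 0;
    __syncthreads();

    // Block scope: only threads of this block synchronize on block_total
    cuda::atomic_ref<int, cuda::thread_scope_block> block_ref(block_total);
    block_ref.fetch_add(1, cuda::std::memory_order_relaxed);

    // Device scope: every thread on the GPU synchronizes on *device_total
    cuda::atomic_ref<int, cuda::thread_scope_device> device_ref(*device_total);
    device_ref.fetch_add(1, cuda::std::memory_order_relaxed);

    __syncthreads();                            // wait until all adds to block_total are done
    if (threadIdx.x == 0) per_block[blockIdx.x] = block_total;
}

int main() {
    int *d_total = nullptr, *d_per_block = nullptr;
    cudaMalloc(&d_total, sizeof(int));
    cudaMalloc(&d_per_block, 4 * sizeof(int));
    cudaMemset(d_total, 0, sizeof(int));

    scoped_counters<<<4, 128>>>(d_total, d_per_block);

    int total = 0, per_block[4] = {0};
    cudaMemcpy(&total, d_total, sizeof(int), cudaMemcpyDeviceToHost);
    cudaMemcpy(per_block, d_per_block, sizeof(per_block), cudaMemcpyDeviceToHost);

    printf("device-wide count = %d (expected 512)\n", total);
    for (int i = 0; i < 4; ++i)
        printf("block %d count = %d (expected 128)\n", i, per_block[i]);

    cudaFree(d_total);
    cudaFree(d_per_block);
    return 0;
}
```

With 4 blocks of 128 threads each, every per-block counter should read 128 and the device-wide counter 512. Choosing the narrower `thread_scope_block` for the shared counter tells the implementation that no cross-block ordering is needed, so it can rely on cheaper block-level synchronization.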