import argparse
import os
import platform
import sys
import time
from pathlib import Path
import torch
import cv2
import numpy as np
import yaml
import serial
import traceback
# Resolve this file's absolute path and derive the project root from it.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # directory that contains this script
if str(ROOT) not in sys.path:
    # Put the script directory on sys.path so the `models` / `utils` imports below can resolve.
    sys.path.append(str(ROOT))
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # re-express ROOT relative to the current working directory
# 导入必要的模块
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (
LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr,
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh
)
from utils.plots import colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode
class Annotator:
    """Small OpenCV helper that draws labelled bounding boxes onto one image."""

    def __init__(self, im, line_width=None):
        """Remember the target image and choose the stroke width.

        Args:
            im: Image array that will be drawn on in place.
            line_width: Stroke width in pixels. When falsy, a width is derived
                from the summed image shape and never drops below 2.
        """
        self.im = im
        if line_width:
            self.lw = line_width
        else:
            self.lw = max(round(sum(im.shape) / 2 * 0.003), 2)

    def box_label(self, box, label='', color=(128, 128, 128)):
        """Draw ``box`` (x1, y1, x2, y2) and, if given, a filled label tag.

        Returns:
            The annotated image (the same object that is stored on the instance).
        """
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(self.im, top_left, bottom_right, color, thickness=self.lw, lineType=cv2.LINE_AA)
        if not label:
            return self.im

        font_scale = self.lw / 3
        text_thickness = max(self.lw - 1, 1)
        text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=text_thickness)[0]

        # Put the tag above the box when there is room, otherwise just below its top edge.
        if top_left[1] - text_h >= 3:
            tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
            text_origin = (top_left[0], top_left[1] - 2)
        else:
            tag_corner = (top_left[0] + text_w, top_left[1] + text_h + 3)
            text_origin = (top_left[0], top_left[1] + text_h + 2)

        cv2.rectangle(self.im, top_left, tag_corner, color, -1, cv2.LINE_AA)
        cv2.putText(self.im, label, text_origin, 0, font_scale, (255, 255, 255),
                    thickness=text_thickness, lineType=cv2.LINE_AA)
        return self.im

    def result(self):
        """Return the annotated image."""
        return self.im
def read_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed document. An empty dict is returned when the file does not
        exist or when the document is empty (or otherwise falsy), so callers
        can always use the result like a mapping of defaults.
    """
    try:
        # Read as UTF-8 explicitly so non-ASCII values load the same on every platform
        # instead of depending on the locale's default encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load yields None for an empty document; normalise that to {} to
            # match the missing-file fallback below.
            return yaml.safe_load(f) or {}
    except FileNotFoundError:
        LOGGER.warning(f"未找到配置文件 '{config_path}',将使用默认值。")
        return {}
def setup_serial(port, baudrate=115200):
    """Try to open the given serial port once.

    Args:
        port: Port name to open; a falsy value skips initialisation.
        baudrate: Baud rate passed to ``serial.Serial``.

    Returns:
        The opened ``serial.Serial`` object, or ``None`` when no port was given
        or opening failed (the failure is logged, not raised).
    """
    connection = None
    if not port:
        LOGGER.warning("未指定串口端口,跳过串口初始化")
    else:
        try:
            # A 1-second timeout is configured on the port object.
            opened = serial.Serial(port, baudrate, timeout=1)
            LOGGER.info(f"成功打开串口 {port}")
            connection = opened
        except serial.SerialException as err:
            LOGGER.error(f"无法打开串口 {port}: {str(err)}")
        except Exception as err:
            LOGGER.error(f"串口初始化异常: {str(err)}")
    return connection
# Distance estimator with aspect-ratio correction and temporal smoothing
class EnhancedDistanceCalculator:
    """Estimate object distance from its bounding-box size.

    The base estimate is ``real_width * focal_length / pixel_width``. It is
    scaled up by at most 15 % for boxes narrower than they are tall, averaged
    over the last five calls, and held at a "stable" value once the average has
    stayed within 0.2 of it for more than five consecutive calls.
    """

    def __init__(self, focal_length, real_width):
        self.focal_length = focal_length
        self.real_width = real_width
        self.distance_history = []  # most recent corrected estimates (at most 5)
        self.stable_count = 0       # consecutive calls whose average stayed near stable_distance
        self.stable_distance = 0.0  # reference value the average is compared against

    def calculate_distance(self, pixel_width, pixel_height):
        """Return the smoothed distance for a box of the given pixel size."""
        # Width/height ratio; a non-positive height is treated as a square box.
        ratio = 1.0
        if pixel_height > 0:
            ratio = pixel_width / pixel_height

        raw = (self.real_width * self.focal_length) / pixel_width

        # Clamp the ratio to [0.5, 1.0]; a ratio of 1.0 or above leaves the estimate untouched.
        clamped = min(1.0, max(0.5, ratio))
        estimate = raw * (1.0 + (1.0 - clamped) * 0.3)

        # Keep a sliding window of the five latest estimates.
        window = self.distance_history
        window.append(estimate)
        if len(window) > 5:
            del window[0]
        mean = sum(window) / len(window)

        # Track how long the mean has stayed within 0.2 of the reference value.
        if abs(mean - self.stable_distance) < 0.2:
            self.stable_count += 1
        else:
            self.stable_count = 0
            self.stable_distance = mean

        # After more than five stable calls, report the held reference instead of the mean.
        return self.stable_distance if self.stable_count > 5 else mean
def calculate_angle(cx, cy, width, height):
    """Return the direction of (cx, cy) as seen from the image centre, in degrees.

    The angle is measured with the vertical axis pointing up, and negative
    results are shifted by 360 so the value is reported on a 0-360 scale.
    """
    centre_x = width // 2
    centre_y = height // 2
    offset_x = cx - centre_x
    # Image rows grow downward, so the vertical offset is flipped.
    offset_y = centre_y - cy
    bearing = np.degrees(np.arctan2(offset_y, offset_x))
    return bearing + 360 if bearing < 0 else bearing
# Lower-cased class names that identify each tracked target in the model's label map.
# Both 'basketball' and the shorter 'ball' are accepted for the ball; extend these sets
# if the trained model uses different labels.
_BALL_CLASS_NAMES = frozenset({'basketball', 'ball'})
_HOOP_CLASS_NAMES = frozenset({'hoop'})


def _class_ids(names, aliases):
    """Return the set of class ids whose lower-cased name is contained in ``aliases``."""
    pairs = names.items() if isinstance(names, dict) else enumerate(names)
    return {int(idx) for idx, label in pairs if str(label).lower() in aliases}


@smart_inference_mode()
def run(
        weights=ROOT / 'combined_model.pt',  # merged model
        source=ROOT / 'data/images',
        data=ROOT / 'data/coco128.yaml',
        ball_diameter=9,
        imgsz=(1920, 1080),
        conf_thres=0.25,
        iou_thres=0.45,
        max_det=1000,
        device='CPU',
        view_img=True,
        save_txt=False,
        save_conf=False,
        save_crop=False,
        nosave=False,
        classes=None,
        agnostic_nms=False,
        augment=False,
        visualize=False,
        update=False,
        project=ROOT / 'runs/detect',
        name='exp',
        exist_ok=False,
        line_thickness=3,
        hide_labels=False,
        hide_conf=False,
        half=False,
        dnn=False,
        vid_stride=1,
        config_path='config.yaml',
        known_width=0.2,  # real width of the calibration object
        known_distance=2.0,  # known distance of the calibration object
        ref_pixel_width=100,  # pixel width of the calibration object
        ball_real_width=0.2,  # real width of the basketball
        hoop_real_width=1.0,  # real width of the hoop
        serial_port=None,  # serial port to use
        serial_baud=115200,  # serial baud rate
        serial_interval=3,  # send over serial every N frames
):
    """Detect the ball and hoop, estimate distance/angle, display/save and report over serial.

    Serial protocol (ASCII, one line per target, written every ``serial_interval`` frames):
        ``1,{angle:.2f},{distance:.2f}\\n`` for the nearest ball,
        ``0,{angle:.2f},{distance:.2f}\\n`` for the nearest hoop,
        a single ``\\n`` when neither target was found in the frame.

    Targets are recognised by class id, resolved once from the model's label map
    using ``_BALL_CLASS_NAMES`` / ``_HOOP_CLASS_NAMES`` (case-insensitive).

    Args:
        weights, data, imgsz, device, half, dnn: Model loading options.
        source, vid_stride: Input selection (file, directory, URL, screen or camera index).
        conf_thres, iou_thres, max_det, classes, agnostic_nms, augment, visualize:
            Inference and NMS options.
        view_img, nosave, project, name, exist_ok, line_thickness: Display and output options.
        update: Strip the optimizer from the weights file after the run.
        known_width, known_distance, ref_pixel_width: Calibration values; the focal
            length is ``ref_pixel_width * known_distance / known_width``.
        ball_real_width, hoop_real_width: Real widths used for distance estimation.
        serial_port, serial_baud, serial_interval: Serial output settings; no port
            means serial output is disabled. Intervals below 1 are treated as 1.
        save_txt: Only creates the ``labels`` directory and affects the final log line.
        ball_diameter, save_conf, save_crop, hide_labels, hide_conf, config_path:
            Accepted for interface compatibility; not used by this function.

    Raises:
        ValueError: If any calibration value is not positive.
    """
    # Classify the input source.
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')
    if is_url and is_file:
        source = check_file(source)

    # Create the output directory.
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)

    # Select the device and load the model.
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)

    # Resolve target class ids once so the per-detection check is a cheap integer lookup.
    # Comparing against a single hard-coded name silently drops every target when the
    # model's label differs, which leaves only the empty keep-alive line on the wire.
    ball_ids = _class_ids(names, _BALL_CLASS_NAMES)
    hoop_ids = _class_ids(names, _HOOP_CLASS_NAMES)
    if not ball_ids:
        LOGGER.warning(f"模型类别中未找到篮球类别 {sorted(_BALL_CLASS_NAMES)},不会发送篮球数据。模型类别: {names}")
    if not hoop_ids:
        LOGGER.warning(f"模型类别中未找到篮筐类别 {sorted(_HOOP_CLASS_NAMES)},不会发送篮筐数据。模型类别: {names}")

    # Focal length from the calibration values (needed before any distance is computed).
    if known_width <= 0 or known_distance <= 0 or ref_pixel_width <= 0:
        raise ValueError("[ERROR] Calibration parameters must be positive values!")
    focal_length = (ref_pixel_width * known_distance) / known_width
    print(f"[Calibration] Focal Length = {focal_length:.2f} px·m")

    # One smoother per target so their histories never mix.
    ball_distance_calculator = EnhancedDistanceCalculator(focal_length, ball_real_width)
    hoop_distance_calculator = EnhancedDistanceCalculator(focal_length, hoop_real_width)

    # A zero or negative interval would make the modulo below fail.
    serial_interval = max(1, int(serial_interval))

    # Open the input, retrying a few times because streams can be slow to come up.
    bs = 1
    max_retries = 5
    dataset = None
    for attempt in range(1, max_retries + 1):
        try:
            if webcam:
                view_img = check_imshow(warn=True)
                dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
                bs = len(dataset)
            elif screenshot:
                dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
            else:
                dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
            break
        except Exception as e:
            LOGGER.warning(f"视频流初始化失败 ({e}),尝试重新连接 ({attempt}/{max_retries})...")
            if attempt < max_retries:
                time.sleep(2)  # wait before the next attempt
    if dataset is None:
        LOGGER.error(f"无法初始化视频流,请检查输入源: {source}")
        return
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Open the serial port (optional).
    ser = setup_serial(serial_port, serial_baud) if serial_port else None
    if ser and not ser.is_open:
        LOGGER.warning("串口未成功打开,继续无串口模式运行")
        ser = None
    elif ser:
        LOGGER.info(f"串口连接已建立: {serial_port}@{serial_baud}bps")

    try:
        # Warm up the model.
        model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
        seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
        frame_counter = 0  # throttles serial output
        stop_requested = False  # set when the user presses 'q' in the preview window

        for path, im, im0s, vid_cap, s in dataset:
            frame_counter += 1
            with dt[0]:
                # Pre-process: to device, to float, scale to 0-1, add batch dimension.
                im = torch.from_numpy(im).to(model.device)
                im = im.half() if model.fp16 else im.float()
                im /= 255
                if len(im.shape) == 3:
                    im = im[None]
            with dt[1]:
                # Inference.
                visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
                pred = model(im, augment=augment, visualize=visualize)
            with dt[2]:
                # NMS; results are moved to the CPU once for all later processing.
                pred = [x.cpu() for x in
                        non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)]

            # Nearest ball / hoop found in the current frame.
            detected_basketball = None
            detected_hoop = None

            for i, det in enumerate(pred):
                seen += 1
                if webcam:
                    p, im0 = path[i], im0s[i].copy()
                else:
                    p, im0 = path, im0s.copy()
                p = Path(p)
                save_path = str(save_dir / p.name)
                s += '%gx%g' % im.shape[2:]
                annotator = Annotator(im0, line_width=line_thickness)
                height, width = im0.shape[:2]

                if len(det):
                    # Rescale boxes from the inference size to the original image size.
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                    s += ', '.join(
                        [f"{(det[:, 5] == c).sum()} {names[int(c)]}{'s' * ((det[:, 5] == c).sum() > 1)}" for c in
                         det[:, 5].unique()])

                    # Convert to plain Python numbers once instead of calling .item() per value.
                    for *xyxy, conf, cls in reversed(det.tolist()):
                        try:
                            pixel_width = xyxy[2] - xyxy[0]
                            pixel_height = xyxy[3] - xyxy[1]
                            class_id = int(cls)
                            class_name = names[class_id]
                            cx = (xyxy[0] + xyxy[2]) / 2
                            cy = (xyxy[1] + xyxy[3]) / 2
                            angle_deg = calculate_angle(cx, cy, width, height)

                            # Labels drawn with cv2.putText are kept ASCII: its built-in fonts
                            # render unsupported characters as '?'.
                            if class_id in ball_ids:
                                distance = ball_distance_calculator.calculate_distance(pixel_width, pixel_height)
                                if detected_basketball is None or distance < detected_basketball['distance']:
                                    detected_basketball = {
                                        'distance': distance,
                                        'angle_deg': angle_deg,
                                        'cx': cx,
                                        'cy': cy
                                    }
                                label = f'{class_name} {conf:.2f} {distance:.2f}m {angle_deg:.1f}deg'
                                cv2.circle(im0, (int(cx), int(cy)), 5, (0, 0, 255), -1)  # ball centre
                            elif class_id in hoop_ids:
                                distance = hoop_distance_calculator.calculate_distance(pixel_width, pixel_height)
                                if detected_hoop is None or distance < detected_hoop['distance']:
                                    detected_hoop = {
                                        'distance': distance,
                                        'angle_deg': angle_deg,
                                        'cx': cx,
                                        'cy': cy
                                    }
                                label = f'{class_name} {conf:.2f} {distance:.2f}m {angle_deg:.1f}deg'
                                cv2.circle(im0, (int(cx), int(cy)), 5, (0, 255, 0), -1)  # hoop centre
                            else:
                                # Unsmoothed estimate: other classes must not be pushed into
                                # the ball's smoothing history.
                                distance = (ball_real_width * focal_length) / pixel_width
                                label = f'{class_name} {conf:.2f} {distance:.2f}m'

                            im0 = annotator.box_label(xyxy, label=label, color=colors(class_id, True))
                            print(f"检测到的 {class_name} - 距离: {distance:.2f}m, 角度: {angle_deg:.1f}°")
                        except Exception as e:
                            print(f"[ERROR] 处理检测结果失败: {e}")
                            traceback.print_exc()

                # Serial output: one write per update carrying every line for this frame.
                if frame_counter % serial_interval == 0 and ser and ser.is_open:
                    try:
                        parts = []
                        if detected_basketball:
                            ball_message = f"1,{detected_basketball['angle_deg']:.2f},{detected_basketball['distance']:.2f}\n"
                            parts.append(ball_message)
                            LOGGER.debug(f"📤 发送篮球数据: {ball_message.strip()}")
                        if detected_hoop:
                            hoop_message = f"0,{detected_hoop['angle_deg']:.2f},{detected_hoop['distance']:.2f}\n"
                            parts.append(hoop_message)
                            LOGGER.debug(f"📤 发送篮筐数据: {hoop_message.strip()}")
                        if not parts:
                            # Nothing detected: send an empty line as a keep-alive.
                            parts.append('\n')
                            LOGGER.debug("📤 发送空行")
                        ser.write(''.join(parts).encode('ascii'))
                    except serial.SerialException as e:
                        LOGGER.error(f"🚨 串口发送失败: {str(e)}")
                        ser.close()
                        ser = None  # mark the port as unusable

                # Image centre marker and crosshair.
                origin_x, origin_y = width // 2, height // 2
                cv2.circle(im0, (origin_x, origin_y), 10, (255, 0, 0), -1)
                cv2.line(im0, (0, origin_y), (width, origin_y), (0, 255, 0), 2)
                cv2.line(im0, (origin_x, 0), (origin_x, height), (0, 255, 0), 2)

                # On-screen summary (ASCII only, see the note on cv2.putText above).
                info_y = 30
                if detected_basketball:
                    info_text = f"Ball: {detected_basketball['distance']:.2f}m, {detected_basketball['angle_deg']:.1f}deg"
                    cv2.putText(im0, info_text, (10, info_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                    info_y += 40
                if detected_hoop:
                    info_text = f"Hoop: {detected_hoop['distance']:.2f}m, {detected_hoop['angle_deg']:.1f}deg"
                    cv2.putText(im0, info_text, (10, info_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                    info_y += 40
                if not detected_basketball and not detected_hoop:
                    cv2.putText(im0, "No target", (10, info_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

                im0 = annotator.result()

                if view_img:
                    if platform.system() == 'Linux' and p not in windows:
                        windows.append(p)
                        cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
                        cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                    cv2.imshow(str(p), im0)
                    if cv2.waitKey(1) == ord('q'):  # press q to quit
                        stop_requested = True
                        break

                if save_img:
                    if dataset.mode == 'image':
                        cv2.imwrite(save_path, im0)
                    else:
                        if vid_path[i] != save_path:
                            # Remember the target so the writer is created once per video,
                            # not once per frame.
                            vid_path[i] = save_path
                            if isinstance(vid_writer[i], cv2.VideoWriter):
                                vid_writer[i].release()
                            if vid_cap:
                                fps = vid_cap.get(cv2.CAP_PROP_FPS)
                                w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                                h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                            else:
                                fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path = str(Path(save_path).with_suffix('.mp4'))
                            vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                        vid_writer[i].write(im0)

            LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
            if stop_requested:
                break

        # Summary (guard against a run that produced no frames).
        t = tuple(x.t / max(seen, 1) * 1E3 for x in dt)
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
        if save_txt or save_img:
            s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
        if update:
            # weights is a list when it comes from the CLI and a single path otherwise.
            strip_optimizer(weights[0] if isinstance(weights, (list, tuple)) else weights)
    finally:
        # Always finalise video files and release the serial port, even after an error.
        for writer in vid_writer:
            if isinstance(writer, cv2.VideoWriter):
                writer.release()
        if ser and ser.is_open:
            ser.close()
            LOGGER.info("🔌 串口连接已关闭")
def parse_opt():
    """Parse the command-line arguments and return the resulting namespace.

    A single ``--imgsz`` value is doubled so the option always holds two
    entries. The parsed values are passed to ``print_args`` before returning.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    def add_switches(switches):
        # Register a run of plain on/off flags in the given order.
        for flag, text in switches:
            add(flag, action='store_true', help=text)

    # Model, input and inference options
    add('--weights', nargs='+', type=str, default=ROOT / 'runs/train/exp33/weights/best.pt',
        help='模型路径或triton URL')
    add('--source', type=str, default='2', help='文件/目录/URL/glob/屏幕/0(摄像头)')
    add('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(可选)数据集yaml路径')
    add('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='推理尺寸h,w')
    add('--conf-thres', type=float, default=0.25, help='置信度阈值')
    add('--iou-thres', type=float, default=0.45, help='NMS IoU阈值')
    add('--max-det', type=int, default=1000, help='每张图像最大检测数')
    add('--device', default='cuda', help='cuda设备, 例如 0 或 0,1,2,3 或 cpu')
    add_switches((
        ('--view-img', '显示结果'),
        ('--save-txt', '保存结果到*.txt'),
        ('--save-conf', '在--save-txt标签中保存置信度'),
        ('--save-crop', '保存裁剪的预测框'),
        ('--nosave', '不保存图像/视频'),
    ))
    add('--classes', nargs='+', type=int, help='按类别过滤: --classes 0, 或 --classes 0 2 3')
    add_switches((
        ('--agnostic-nms', '类别不可知的NMS'),
        ('--augment', '增强推理'),
        ('--visualize', '可视化特征'),
        ('--update', '更新所有模型'),
    ))

    # Output options
    add('--project', default=ROOT / 'runs/detect', help='保存结果到项目/名称')
    add('--name', default='exp', help='保存结果到项目/名称')
    add_switches((('--exist-ok', '现有项目/名称可用,不增加'),))
    add('--line-thickness', default=3, type=int, help='边界框厚度(像素)')
    add_switches((
        ('--hide-labels', '隐藏标签'),
        ('--hide-conf', '隐藏置信度'),
        ('--half', '使用FP16半精度推理'),
        ('--dnn', '使用OpenCV DNN进行ONNX推理'),
    ))
    add('--vid-stride', type=int, default=1, help='视频帧步长')

    # Calibration and real-size options
    for flag, default, text in (
            ('--known-width', 0.2, '参考物体的已知宽度(米)'),
            ('--known-distance', 2.0, '参考物体的已知距离(米)'),
            ('--ref-pixel-width', 100, '图像中参考物体的像素宽度'),
            ('--ball-real-width', 0.2, '篮球的实际宽度(米)'),
            ('--hoop-real-width', 1.0, '篮筐的实际宽度(米)'),
    ):
        add(flag, type=float, default=default, help=text)

    # Serial options
    add('--serial-port', type=str, default=None, help='指定串口端口 (例如 COM3 或 /dev/ttyUSB0)')
    add('--serial-baud', type=int, default=115200, help='串口波特率 (默认: 115200)')
    add('--serial-interval', type=int, default=3, help='串口发送间隔 (帧数, 默认: 每3帧发送一次)')

    opt = cli.parse_args()
    if len(opt.imgsz) == 1:
        opt.imgsz *= 2  # expand a single size to two entries
    print_args(vars(opt))
    return opt
def main(opt):
    """Entry point: forward every parsed option to ``run`` as a keyword argument."""
    options = vars(opt)
    run(**options)
if __name__ == "__main__":
    # Parse the command-line arguments and start detection.
    opt = parse_opt()
    main(opt)

# Field note (was pasted onto the `main(opt)` line, which made the module unparseable):
#   Console output during a serial test:
#       detected ball - distance: 5.35m, angle: 183.
#   Port-monitor record captured at the same time:
#       99 16:01:42.188 0.00010520 python.exe IRP_MJ_WRITE COM7 SUCCESS Length: 1, Data: 0A
#   i.e. an object labelled "ball" was detected while only a single 0x0A byte (an empty
#   line) was written to COM7.
#   Request from the author: the serial transmission is clearly wrong and should be
#   fixed without adding computational load or unnecessary features.