import argparse
import csv
import os
import platform
import sys
from math import radians, cos, sin, sqrt
from pathlib import Path
import cv2
import numpy as np
import torch
# Locate this script on disk and make its directory importable, so the
# project-local packages imported below (models/, utils/) can be resolved.
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # YOLOv5 root directory (the folder containing this script)
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))  # add ROOT to the import path
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # re-express ROOT relative to the working directory
from ultralytics.utils.plotting import Annotator, colors, save_one_box
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (
LOGGER,
Profile,
check_file,
check_img_size,
check_imshow,
check_requirements,
colorstr,
increment_path,
non_max_suppression,
print_args,
scale_boxes,
strip_optimizer,
xyxy2xywh,
)
from utils.torch_utils import select_device, smart_inference_mode
# Module-level tracking state: bullets seen in the previous frame, stored as {id: (x, y)}
prev_bullets = {}
next_bullet_id = 1 # next unused global bullet ID
def generate_tilted_grid(uav_center, img_shape, grid_spacing=20, tilt_angle=8):
    """Build an evenly spaced grid rotated by ``tilt_angle`` degrees about the UAV centre.

    :param uav_center: UAV centre point ``(x, y)``; also the pivot of the rotation
    :param img_shape: image size as ``(height, width)``
    :param grid_spacing: distance between neighbouring grid lines, in pixels
    :param tilt_angle: rotation applied to the grid, in degrees
    :return: ``(horizontal_lines, vertical_lines)``; every line is a pair of
        endpoints ``((x1, y1), (x2, y2))``. Endpoints are not clipped, so they
        may lie outside the image.
    """
    height, width = img_shape
    theta = radians(tilt_angle)
    cos_a, sin_a = cos(theta), sin(theta)

    def rotate(px, py):
        """Rotate the point (px, py) about the UAV centre."""
        return (
            uav_center[0] + (px - uav_center[0]) * cos_a + (py - uav_center[1]) * sin_a,
            uav_center[1] - (px - uav_center[0]) * sin_a + (py - uav_center[1]) * cos_a,
        )

    # Rows: segments spanning x = 0..width, stepped along y around the UAV centre.
    horizontal_lines = [
        (rotate(0, uav_center[1] + dy), rotate(width, uav_center[1] + dy))
        for dy in range(-height, height, grid_spacing)
    ]
    # Columns: segments spanning y = 0..height, stepped along x around the UAV centre.
    vertical_lines = [
        (rotate(uav_center[0] + dx, 0), rotate(uav_center[0] + dx, height))
        for dx in range(-width, width, grid_spacing)
    ]
    return horizontal_lines, vertical_lines
def calculate_grid_offset(bullet_center, uav_center, grid_spacing=20, tilt_angle=8):
    """Express a bullet's position as grid-cell offsets from the UAV in the tilted grid.

    :param bullet_center: bullet centre ``(x, y)``; plain numbers or torch scalars
    :param uav_center: UAV centre ``(x, y)``; plain numbers or torch scalars
    :param grid_spacing: size of one grid cell, in pixels
    :param tilt_angle: grid rotation, in degrees
    :return: ``(offset_x, offset_y)`` in grid cells, each rounded to 2 decimals
    """
    theta = radians(tilt_angle)
    cos_a, sin_a = cos(theta), sin(theta)

    def as_numbers(point):
        """Unwrap torch scalars so the arithmetic below runs on plain Python numbers."""
        if isinstance(point[0], torch.Tensor):
            return (point[0].item(), point[1].item())
        return point

    bullet_center = as_numbers(bullet_center)
    uav_center = as_numbers(uav_center)

    # Displacement of the bullet relative to the UAV, in image axes.
    delta_x = bullet_center[0] - uav_center[0]
    delta_y = bullet_center[1] - uav_center[1]

    # Project that displacement onto the rotated grid axes.
    along_x = delta_x * cos_a + delta_y * sin_a
    along_y = -delta_x * sin_a + delta_y * cos_a

    return round(along_x / grid_spacing, 2), round(along_y / grid_spacing, 2)
def predict_next_position(prev_pos, tilt_angle=8, distance=0):
    """Predict where a bullet lands after moving ``distance`` pixels along the tilted track.

    The step is applied with a negative sign on both axes, i.e. the bullet is
    assumed to travel towards decreasing x and y.

    :param prev_pos: previous position ``(x, y)``
    :param tilt_angle: direction of travel, in degrees
    :param distance: distance travelled, in pixels
    :return: predicted position ``(x, y)``
    """
    theta = radians(tilt_angle)
    step = (-distance * cos(theta), -distance * sin(theta))
    return (prev_pos[0] + step[0], prev_pos[1] + step[1])
def match_bullets(current_bullets, prev_bullets, tilt_angle=8, distance_threshold=50,
                  step_distance=None):
    """Associate this frame's bullets with the previous frame's bullets.

    Every previous bullet is projected forward along the tilted track with
    ``predict_next_position``. Each current bullet then greedily takes the
    closest still-unclaimed prediction that lies strictly within
    ``distance_threshold`` pixels and inherits its ID; a bullet with no such
    prediction gets a fresh ID from the module-level ``next_bullet_id`` counter.

    :param current_bullets: iterable of current-frame centres ``(x, y)``
    :param prev_bullets: mapping ``{id: (x, y)}`` from the previous frame
    :param tilt_angle: direction of travel, in degrees
    :param distance_threshold: maximum prediction-to-detection distance for a match, in pixels
    :param step_distance: expected travel between frames, in pixels. ``None``
        (the default) reuses ``distance_threshold``, which is what this function
        always did; pass a value to tune the motion model independently of the
        matching radius.
    :return: mapping ``{id: (x, y)}`` for the current frame, in the order of
        ``current_bullets``
    """
    global next_bullet_id

    if step_distance is None:
        step_distance = distance_threshold

    # Predictions depend only on the previous frame, so compute them once
    # instead of once per current bullet.
    predicted = {
        prev_id: predict_next_position(prev_pos, tilt_angle, step_distance)
        for prev_id, prev_pos in prev_bullets.items()
    }

    matched = {}
    used_prev_ids = set()
    for curr_pos in current_bullets:
        best_match_id = None
        min_distance = float("inf")
        for prev_id, (pred_x, pred_y) in predicted.items():
            if prev_id in used_prev_ids:
                continue  # each previous bullet may be claimed only once
            distance = sqrt((curr_pos[0] - pred_x) ** 2 + (curr_pos[1] - pred_y) ** 2)
            if distance < distance_threshold and distance < min_distance:
                min_distance = distance
                best_match_id = prev_id

        if best_match_id is None:
            # No plausible predecessor: this is a new bullet.
            best_match_id = next_bullet_id
            next_bullet_id += 1
        else:
            used_prev_ids.add(best_match_id)
        matched[best_match_id] = curr_pos
    return matched
def init_offset_csv(csv_path):
    """Create (or truncate) the offset CSV and write its header row.

    :param csv_path: destination path of the CSV file
    """
    header = [
        "图像名称", "帧序号", "UAV中心点X", "UAV中心点Y",
        # Rectangle columns (width, height, corner coordinates).
        "紫色矩形宽(px)", "紫色矩形高(px)", "紫色矩形坐标(x1,y1,x2,y2)",
        "Bullet全局ID", "Bullet编号", "Bullet中心点X", "Bullet中心点Y",
        "横向偏移格数", "纵向偏移格数", "网格间距(px)", "倾斜角度(度)",
    ]
    with open(csv_path, mode="w", newline="", encoding="utf-8") as csv_file:
        csv.writer(csv_file).writerow(header)
def write_offset_to_csv(csv_path, data):
    """Append one row to the offset CSV.

    :param csv_path: path of the CSV file to append to
    :param data: sequence of cell values forming the row
    """
    with open(csv_path, mode="a", newline="", encoding="utf-8") as csv_file:
        csv.writer(csv_file).writerow(data)
def get_purple_rect(uav_center, rect_width, rect_height, img_shape):
    """
    Build the purple rectangle centred on the UAV, clipped to the image.

    :param uav_center: UAV centre point (x, y)
    :param rect_width: rectangle width in pixels
    :param rect_height: rectangle height in pixels
    :param img_shape: image size (h, w)
    :return: rectangle corners (x1, y1, x2, y2), each clamped to the image bounds
    """
    img_h, img_w = img_shape
    cx, cy = uav_center
    half_w = rect_width / 2
    half_h = rect_height / 2

    def clamp(value, upper):
        """Limit ``value`` to the range [0, upper]."""
        return max(0, min(value, upper))

    return (
        clamp(cx - half_w, img_w),
        clamp(cy - half_h, img_h),
        clamp(cx + half_w, img_w),
        clamp(cy + half_h, img_h),
    )
@smart_inference_mode()
def run(
    weights=ROOT / "weights/yolov5s.pt",  # model path or triton URL
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
    max_det=1000,  # maximum detections per image
    device="",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    view_img=False,  # show results
    save_txt=False,  # save results to *.txt
    save_format=0,  # 0 for YOLO, 1 for Pascal-VOC
    save_csv=True,  # write the per-bullet grid-offset CSV
    save_conf=False,  # save confidences in --save-txt labels
    save_crop=False,  # save cropped prediction boxes
    nosave=False,  # do not save annotated images/videos
    classes=None,  # filter by class: --class 0, or --class 0 2 3
    agnostic_nms=False,  # class-agnostic NMS
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # root directory for results
    name="exp",  # results sub-directory name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
    hide_labels=False,  # hide labels
    hide_conf=False,  # hide confidences
    half=False,  # use FP16 half-precision inference
    dnn=False,  # use OpenCV DNN for ONNX inference
    vid_stride=1,  # video frame-rate stride
    grid_spacing=20,  # grid spacing (pixels)
    tilt_angle=8,  # grid tilt angle (degrees)
    purple_rect_width=400,  # purple rectangle width (pixels)
    purple_rect_height=300,  # purple rectangle height (pixels)
    distance_threshold=200,  # bullet tracking distance threshold (pixels)
):
    """Detect a UAV and bullets, overlay a tilted grid, track bullets and log their grid offsets.

    For every frame that contains a detection whose class name is exactly
    ``"UAV"``, the function draws a purple rectangle centred on the
    highest-confidence UAV, draws the tilted grid, keeps the ``"bullet"``
    detections whose centre lies inside the rectangle, assigns them
    cross-frame IDs via ``match_bullets`` and writes one CSV row per bullet.
    Frames without a UAV detection are saved without any overlay.

    ``save_format``, ``save_conf``, ``save_crop``, ``hide_labels`` and
    ``hide_conf`` are accepted so the CLI namespace can be splatted into this
    function, but they are not read here; ``save_txt`` only controls whether
    the ``labels`` directory is created.

    Resets the module-level tracking state (``prev_bullets``,
    ``next_bullet_id``) at the start of every call.
    """
    global prev_bullets, next_bullet_id
    prev_bullets = {}  # forget bullets from any earlier run
    next_bullet_id = 1  # restart the ID counter

    source = str(source)
    save_img = not nosave and not source.endswith(".txt")  # whether annotated output is saved
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))
    webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
    screenshot = source.lower().startswith("screen")
    if is_url and is_file:
        source = check_file(source)  # download

    # Output directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
    (save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)
    offset_csv_path = save_dir / "bullet_grid_offsets.csv"
    if save_csv:
        init_offset_csv(offset_csv_path)  # write the header row

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)

    # Every overlay below hinges on these two exact class names. Say so up
    # front, otherwise a model with different names silently yields output
    # images identical to the input.
    uav_label, bullet_label = "UAV", "bullet"
    class_names = list(names.values()) if isinstance(names, dict) else list(names)
    missing_labels = [lbl for lbl in (uav_label, bullet_label) if lbl not in class_names]
    if missing_labels:
        LOGGER.warning(
            f"Expected class name(s) {missing_labels} not found in model classes {class_names}; "
            "matching is case-sensitive and unmatched detections are not drawn"
        )

    # Dataloader
    bs = 1  # batch_size
    if webcam:
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    frame_count = 0  # number of dataset iterations so far; written to the CSV as the frame index
    for path, im, im0s, vid_cap, s in dataset:
        frame_count += 1
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # add batch dimension
            if model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)

        # Inference
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            if model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    if pred is None:
                        pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)
                    else:
                        pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)
                pred = [pred, None]
            else:
                pred = model(im, augment=augment, visualize=visualize)

        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch of streams
                p, im0 = path[i], im0s[i].copy()
                s += f"{i}: "
            else:
                p, im0 = path, im0s.copy()
            p = Path(p)
            save_path = str(save_dir / p.name)  # where the annotated frame is written
            s += "{:g}x{:g} ".format(*im.shape[2:])  # log string

            annotator = Annotator(im0, line_width=line_thickness, example=str(names))

            uav_boxes = []  # (xyxy, conf) per UAV detection
            bullet_boxes = []  # (xyxy, conf, class_index) per bullet detection
            if len(det):
                # Rescale boxes from the inference size to the original image size
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    label = names[c]
                    if label == uav_label:
                        uav_boxes.append((xyxy, float(conf)))
                    elif label == bullet_label:
                        bullet_boxes.append((xyxy, float(conf), c))

            if uav_boxes:
                # Use the most confident UAV as the reference rather than
                # whichever one happens to be iterated first.
                uav_xyxy, _ = max(uav_boxes, key=lambda box: box[1])
                uav_center = _box_center(uav_xyxy)

                # Purple rectangle centred on the UAV, clipped to the image
                img_h, img_w = im0.shape[:2]
                x1, y1, x2, y2 = get_purple_rect(
                    uav_center, purple_rect_width, purple_rect_height, (img_h, img_w)
                )
                annotator.box_label(
                    [x1, y1, x2, y2],
                    f"Purple Rect (W={purple_rect_width}, H={purple_rect_height})",
                    color=(128, 0, 128),  # purple
                )

                # Keep only the bullets whose centre lies inside the rectangle
                tracked_bullets = []  # (xyxy, class_index, center)
                for bullet_xyxy, _, bullet_cls in bullet_boxes:
                    bullet_center = _box_center(bullet_xyxy)
                    if x1 <= bullet_center[0] <= x2 and y1 <= bullet_center[1] <= y2:
                        tracked_bullets.append((bullet_xyxy, bullet_cls, bullet_center))
                current_bullet_positions = [center for _, _, center in tracked_bullets]

                # Tilted grid: red rows, green columns
                horizontal_lines, vertical_lines = generate_tilted_grid(
                    uav_center, (img_h, img_w), grid_spacing, tilt_angle
                )
                for grid_lines, line_color in ((horizontal_lines, (0, 0, 255)), (vertical_lines, (0, 255, 0))):
                    for start, end in grid_lines:
                        cv2.line(
                            im0,
                            (int(start[0]), int(start[1])),
                            (int(end[0]), int(end[1])),
                            line_color,
                            1,
                        )

                # Associate this frame's bullets with the previous frame's
                matched_bullets = match_bullets(
                    current_bullet_positions,
                    prev_bullets,
                    tilt_angle=tilt_angle,
                    distance_threshold=distance_threshold,
                )
                prev_bullets = matched_bullets.copy()

                # match_bullets stores the very tuples it was given, so the
                # centre itself is a reliable key for recovering each ID.
                id_by_center = {pos: bid for bid, pos in matched_bullets.items()}
                purple_rect_str = f"({x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f})"

                for bullet_idx, (bullet_xyxy, bullet_cls, bullet_center) in enumerate(tracked_bullets, 1):
                    bullet_id = id_by_center.get(bullet_center)
                    offset_x, offset_y = calculate_grid_offset(
                        bullet_center, uav_center, grid_spacing, tilt_angle
                    )
                    label = f"Bullet {bullet_id} ({offset_x}, {offset_y})"
                    # Colour by the bullet's own class, not by whichever
                    # detection the classification loop visited last.
                    annotator.box_label(bullet_xyxy, label, color=colors(bullet_cls, True))

                    if save_csv:
                        write_offset_to_csv(offset_csv_path, [
                            p.name,  # image name
                            frame_count,  # frame index
                            round(uav_center[0], 1),  # UAV centre X
                            round(uav_center[1], 1),  # UAV centre Y
                            purple_rect_width,  # rectangle width
                            purple_rect_height,  # rectangle height
                            purple_rect_str,  # rectangle corners
                            bullet_id,  # cross-frame bullet ID
                            bullet_idx,  # index within this frame
                            round(bullet_center[0], 1),  # bullet centre X
                            round(bullet_center[1], 1),  # bullet centre Y
                            offset_x,  # horizontal offset (cells)
                            offset_y,  # vertical offset (cells)
                            grid_spacing,  # grid spacing
                            tilt_angle,  # tilt angle
                        ])
            else:
                # Without a UAV there is no reference point: nothing is drawn.
                if bullet_boxes and save_csv:
                    write_offset_to_csv(offset_csv_path, [
                        p.name, frame_count, "", "",
                        purple_rect_width, purple_rect_height, "无UAV,无法生成",
                        "", "", "", "", "", "", grid_spacing, tilt_angle
                    ])
                if len(det):
                    # Report this even when CSV output is off, so an
                    # un-annotated output frame is explained in the log.
                    LOGGER.warning(f"图像 {p.name} 未检测到UAV,无法生成紫色矩形框和计算偏移量")

            # Show results
            im0 = annotator.result()
            if view_img:
                if platform.system() == "Linux" and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save annotated output
            if save_img:
                if dataset.mode == "image":
                    cv2.imwrite(save_path, im0)
                    LOGGER.info(f"可视化结果已保存至: {save_path}")
                else:  # video or stream
                    if vid_path[i] != save_path:  # a new video started
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # close the previous writer
                        if vid_cap:  # video file
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            frame_w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            frame_h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, frame_w, frame_h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix(".mp4"))
                        vid_writer[i] = cv2.VideoWriter(
                            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (frame_w, frame_h)
                        )
                    vid_writer[i].write(im0)

        # Per-batch timing
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # Close any video writers that are still open so the output files are finalised.
    for writer in vid_writer:
        if isinstance(writer, cv2.VideoWriter):
            writer.release()

    # Summary (guard against an empty dataset, where seen stays 0)
    t = tuple(x.t / max(seen, 1) * 1e3 for x in dt)
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
    if save_csv:
        LOGGER.info(f"偏移数据已保存至: {offset_csv_path}")
    if save_img:
        LOGGER.info(f"所有可视化结果已保存至: {save_dir}")
    if update:
        # weights is a list when given on the command line, but a single
        # str/Path when the default is used.
        strip_optimizer(weights[0] if isinstance(weights, (list, tuple)) else weights)


def _box_center(xyxy):
    """Return the centre of an ``(x1, y1, x2, y2)`` box as a pair of Python floats."""
    bx1, by1, bx2, by2 = xyxy
    return float((bx1 + bx2) / 2), float((by1 + by2) / 2)
def parse_opt():
    """Parse the command-line options and return them as an ``argparse.Namespace``.

    Every option maps onto a keyword argument of ``run``. ``--weights`` yields
    a list when given on the command line and a plain string when the default
    is used. A single ``--imgsz`` value is expanded to a square
    ``[size, size]``.

    NOTE(review): several defaults here (grid spacing, rectangle size,
    distance threshold) differ from the defaults in ``run``'s signature;
    confirm which set is intended.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", nargs="+", type=str,
                        default="/home/wsj/下载/yolov5-master/yolov5-master/best.pt", help="模型路径")
    parser.add_argument("--source", type=str, default=ROOT / "data/images", help="输入源")
    parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="数据集配置")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="推理尺寸")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="置信度阈值")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS阈值")
    parser.add_argument("--max-det", type=int, default=1000, help="最大检测数")
    parser.add_argument("--device", default="", help="设备(cpu/cuda)")
    parser.add_argument("--view-img", action="store_true", help="实时显示结果")
    parser.add_argument("--save-txt", action="store_true", help="保存txt标签")
    parser.add_argument("--save-format", type=int, default=0, help="标签格式")
    parser.add_argument("--save-csv", action="store_true", default=True, help="保存偏移CSV")
    # --save-csv is on by default and store_true can only turn it on, so give
    # callers an explicit way to switch the CSV off.
    parser.add_argument("--no-save-csv", dest="save_csv", action="store_false", help="不保存偏移CSV")
    parser.add_argument("--save-conf", action="store_true", help="保存置信度")
    parser.add_argument("--save-crop", action="store_true", help="保存裁剪图像")
    parser.add_argument("--nosave", action="store_true", default=False, help="不保存任何结果")
    parser.add_argument("--classes", nargs="+", type=int, help="过滤类别")
    parser.add_argument("--agnostic-nms", action="store_true", help="类别无关NMS")
    parser.add_argument("--augment", action="store_true", help="增强推理")
    parser.add_argument("--visualize", action="store_true", help="可视化特征")
    parser.add_argument("--update", action="store_true", help="更新模型")
    parser.add_argument("--project", default=ROOT / "runs/detect", help="结果根目录")
    parser.add_argument("--name", default="exp", help="结果子目录")
    parser.add_argument("--exist-ok", action="store_true", help="允许覆盖目录")
    parser.add_argument("--line-thickness", default=3, type=int, help="框线厚度")
    parser.add_argument("--hide-labels", default=False, action="store_true", help="隐藏标签")
    parser.add_argument("--hide-conf", default=False, action="store_true", help="隐藏置信度")
    parser.add_argument("--half", action="store_true", help="半精度推理")
    parser.add_argument("--dnn", action="store_true", help="使用DNN")
    parser.add_argument("--vid-stride", type=int, default=1, help="视频步长")
    parser.add_argument("--grid-spacing", type=int, default=50, help="网格间距")
    parser.add_argument("--tilt-angle", type=int, default=8, help="倾斜角度(度)")
    parser.add_argument("--purple-rect-width", type=int, default=300, help="紫色矩形框宽度(像素)")
    parser.add_argument("--purple-rect-height", type=int, default=100, help="紫色矩形框高度(像素)")
    parser.add_argument("--distance-threshold", type=int, default=30, help="Bullet追踪的距离阈值(像素)")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand [size] to [size, size]
    print_args(vars(opt))
    return opt
def main(opt):
    """Check the requirements file, then run detection with the parsed options.

    :param opt: namespace whose attributes are passed to ``run`` as keyword arguments
    """
    requirements_file = ROOT / "requirements.txt"
    check_requirements(requirements_file, exclude=("tensorboard", "thop"))
    run_kwargs = vars(opt)
    run(**run_kwargs)
# Script entry point: parse the command line, then run detection.
# NOTE(review): the pasted file ended with the question "why is the output
# identical to the input image?". In run(), every overlay is drawn only for
# frames with a detection whose class name is exactly "UAV"; check the model's
# class names and the confidence threshold first.
if __name__ == "__main__":
    opt = parse_opt()
    main(opt)