import cv2
import numpy as np
import torch
import argparse
from pathlib import Path
from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression
def detect_crosswalk(image, *, min_aspect_ratio=2, min_area=5000,
                     lower_white=(0, 0, 200), upper_white=(180, 30, 255)):
    """Return the x coordinate of the crosswalk's vertical centre line.

    The image is thresholded in HSV space for low-saturation, high-value
    (white-ish) pixels, the mask is morphologically closed, and the external
    contours of the result are examined. Among the contours whose bounding
    box is wider than ``min_aspect_ratio`` times its height and larger than
    ``min_area`` pixels, the one with the largest bounding-box area is taken
    as the crosswalk.

    Args:
        image: BGR image as a NumPy array (it is converted with
            ``cv2.COLOR_BGR2HSV``).
        min_aspect_ratio: A bounding box qualifies only if
            ``width / height`` is strictly greater than this value.
        min_area: A bounding box qualifies only if ``width * height`` is
            strictly greater than this value (in pixels).
        lower_white: Inclusive lower HSV bound passed to ``cv2.inRange``.
        upper_white: Inclusive upper HSV bound passed to ``cv2.inRange``.

    Returns:
        The x coordinate of the horizontal middle of the selected bounding
        box, or the horizontal middle of the image when no contour
        qualifies.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, np.array(lower_white), np.array(upper_white))

    # Closing fills small gaps in the mask; presumably this is meant to merge
    # the individual stripes into a single wide blob.
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)

    # findContours returns 2 values on some OpenCV versions and 3 on others;
    # the contour list is the first element in the former case and the second
    # in the latter.
    found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = found[0] if len(found) == 2 else found[1]

    best_rect = None
    best_area = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        area = w * h
        if w / h <= min_aspect_ratio or area <= min_area:
            continue
        if area > best_area:
            best_area = area
            best_rect = (x, y, w, h)

    if best_rect is None:
        # No crosswalk candidate: fall back to the middle of the image.
        return image.shape[1] // 2

    x, _, w, _ = best_rect
    return x + w // 2
# Class indices that get annotated. NOTE(review): with the standard
# COCO-trained YOLOv5 weights these would be person, car, bus and truck --
# confirm against the class names of the model actually loaded.
_TARGET_CLASS_IDS = frozenset({0, 2, 5, 7})


def _model_stride(model, default=32):
    """Return the model's largest stride, or *default* if it exposes none."""
    stride = getattr(model, "stride", None)
    if stride is None:
        return default
    return int(torch.as_tensor(stride).max())


def process_frame(frame, model, device):
    """Run detection on one frame and annotate it in place.

    The frame is passed through the YOLOv5 model, the crosswalk centre line
    is located with ``detect_crosswalk`` and drawn as a yellow vertical
    line, and every detection whose class is in ``_TARGET_CLASS_IDS`` gets a
    bounding box and a label telling on which side of the centre line the
    box centre lies (green / "Left", red / "Right").

    Args:
        frame: BGR image as a NumPy array; it is modified in place.
        model: Detection model; called as ``model(tensor)[0]`` and the
            result is handed to ``non_max_suppression``.
        device: ``torch.device`` the input tensor is moved to.

    Returns:
        The same ``frame`` object, with the annotations drawn on it.
    """
    frame_h, frame_w = frame.shape[:2]

    # The network downsamples by its stride, so height and width must be
    # multiples of it or the feature maps of different branches no longer
    # line up. Padding only the bottom and right edges keeps the origin
    # fixed, so predicted box coordinates remain valid frame coordinates.
    stride = _model_stride(model)
    pad_bottom = -frame_h % stride
    pad_right = -frame_w % stride

    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if pad_bottom or pad_right:
        # 114 is the mid-gray fill YOLOv5's own letterbox helper uses.
        img_rgb = cv2.copyMakeBorder(img_rgb, 0, pad_bottom, 0, pad_right,
                                     cv2.BORDER_CONSTANT,
                                     value=(114, 114, 114))

    # HWC uint8 -> 1xCxHxW float in [0, 1].
    img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float().div(255)
    img_tensor = img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img_tensor)[0]
        pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)

    # Locate the crosswalk on the still unannotated frame, then draw it.
    center_x = detect_crosswalk(frame)
    cv2.line(frame, (center_x, 0), (center_x, frame_h), (0, 255, 255), 2)

    if pred and pred[0] is not None:
        # One device-to-host transfer for the whole frame instead of one per
        # detection.
        detections = pred[0].cpu().numpy()
        for x1, y1, x2, y2, _conf, cls in detections[:, :6]:
            if int(cls) not in _TARGET_CLASS_IDS:
                continue
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            obj_center = (x1 + x2) // 2
            on_left = obj_center < center_x
            # cv2.putText's Hershey fonts only cover ASCII; non-ASCII text
            # (such as the former Chinese labels) is drawn as question marks.
            position = "Left" if on_left else "Right"
            color = (0, 255, 0) if on_left else (0, 0, 255)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            # Keep the label inside the image for boxes touching the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(frame, position, (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    return frame
# File suffixes (lower-case) that are opened as video rather than as an image.
_VIDEO_SUFFIXES = frozenset({'.mp4', '.avi', '.mov', '.mkv'})


def main(input_path='test.jpg'):
    """Load the model and run detection on an image, a video or a camera.

    Args:
        input_path: Source to process, as a string. A string of digits is
            used as a camera index, a path whose suffix is in
            ``_VIDEO_SUFFIXES`` (case-insensitive) is opened as a video
            file, and anything else is read as a single image.

    Raises:
        FileNotFoundError: If ``yolov5s.pt`` does not exist in the current
            working directory.
        ValueError: If the image cannot be read or the video source cannot
            be opened.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_path = Path('yolov5s.pt')
    if not model_path.exists():
        raise FileNotFoundError(f"Model file {model_path} not found")
    model = attempt_load(model_path).to(device)

    is_camera = input_path.isdigit()
    is_video = is_camera or Path(input_path).suffix.lower() in _VIDEO_SUFFIXES

    try:
        if is_video:
            cap = cv2.VideoCapture(int(input_path) if is_camera else input_path)
            try:
                # Fail loudly instead of silently showing nothing, matching
                # the behaviour of the image branch below.
                if not cap.isOpened():
                    raise ValueError(
                        f"Cannot open video source {input_path}")
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    processed = process_frame(frame, model, device)
                    cv2.imshow('Detection', processed)
                    # Press 'q' to stop playback early.
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
            finally:
                # Release the capture even if processing raised.
                cap.release()
        else:
            frame = cv2.imread(input_path)
            if frame is None:
                raise ValueError(f"Cannot read image from {input_path}")
            processed = process_frame(frame, model, device)
            cv2.imshow('Result', processed)
            cv2.waitKey(0)
    finally:
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Command-line entry point: --input takes an image path, a video path or
    # a camera index and is forwarded to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default='test.jpg',
                        help='输入路径 (图片/视频/摄像头ID)')
    args = parser.parse_args()
    main(args.input)