OpenCV 多目标追踪容器（MultiTracker）
之前做过一个多目标追踪的项目，尝试了一下 OpenCV 提供的多目标追踪容器（MultiTracker），个人感觉效果一般。
# coding:utf-8
# @Time : 14/12/2018 17:07
# @Author : SuRui
import cv2
from functools import wraps
def GetFps(func):
    """Decorator that times one call of ``func`` with OpenCV's tick counter.

    The decorated function returns ``(result, fps)`` instead of ``result``, where
    ``fps`` is ticks-per-second divided by the ticks the call took, i.e. the
    reciprocal of the call's duration in seconds.

    Positional and keyword arguments are both forwarded to ``func``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_ticks = cv2.getTickCount()
        result = func(*args, **kwargs)
        elapsed_ticks = cv2.getTickCount() - start_ticks
        fps = cv2.getTickFrequency() / elapsed_ticks
        return result, fps
    return wrapper
class MultipleTargetTracker:
    """Track several objects at once with an OpenCV MultiTracker.

    Usage: set ``method_name`` (e.g. ``"KCF"``), call ``update_multi_tracker``
    with a frame and the initial boxes, then call ``trace`` on each following
    frame. ``trace`` returns ``(boxes, fps)`` because it is wrapped by ``GetFps``.
    """

    # Supported tracker names -> name of the OpenCV factory that creates one.
    # Resolved lazily in ``tracker_method`` so that a tracker missing from the
    # installed OpenCV build only fails when it is actually selected.
    _TRACKER_FACTORIES = {
        'BOOSTING': 'TrackerBoosting_create',
        'CSRT': 'TrackerCSRT_create',
        'MIL': 'TrackerMIL_create',
        'KCF': 'TrackerKCF_create',
        'TLD': 'TrackerTLD_create',
        'MEDIANFLOW': 'TrackerMedianFlow_create',
        'MOSSE': 'TrackerMOSSE_create',
        'GOTURN': 'TrackerGOTURN_create',
    }

    def __init__(self):
        self.multi_tracker = self._cv_namespace().MultiTracker_create()
        self._method_name = None

    @staticmethod
    def _cv_namespace():
        """Return the module that exposes the MultiTracker API.

        OpenCV >= 4.5.1 moved ``MultiTracker_create`` and the tracker factories it
        accepts into ``cv2.legacy``; older builds expose them directly on ``cv2``.
        The multi-tracker and its trackers must come from the same namespace.
        """
        if hasattr(cv2, 'MultiTracker_create'):
            return cv2
        return cv2.legacy

    @property
    def method_name(self):
        """Name of the selected tracker algorithm; ``None`` until it is set."""
        return self._method_name

    @method_name.setter
    def method_name(self, method):
        self._method_name = method

    @property
    def tracker_method(self):
        """Factory for the tracker algorithm selected by ``method_name``.

        Choose one of: 'BOOSTING', 'CSRT', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW',
        'MOSSE', 'GOTURN'.

        :return: the OpenCV ``Tracker*_create`` callable for that algorithm.
        :raises KeyError: if ``method_name`` is unset or not one of the names above.
        :raises AttributeError: if the installed OpenCV build lacks that tracker.
        """
        factory_name = self._TRACKER_FACTORIES.get(self.method_name)
        if factory_name is None:
            raise KeyError('Unknown tracker method %r; choose one of %s'
                           % (self.method_name, ', '.join(self._TRACKER_FACTORIES)))
        return getattr(self._cv_namespace(), factory_name)

    def update_multi_tracker(self, image, boxes):
        """(Re)initialise tracking on ``image`` with one tracker per box.

        Any previously tracked objects are discarded: a fresh MultiTracker is
        created every time this is called.

        :param image: the frame on which ``boxes`` were selected.
        :param boxes: sequence of boxes; items 0..3 of each are (x, y, w, h).
        :raises NotImplementedError: if a tracker cannot be added or the first
            update fails (type kept for compatibility with existing callers).
        """
        self.multi_tracker = self._cv_namespace().MultiTracker_create()
        for box in boxes:
            rect = (box[0], box[1], box[2], box[3])
            # Each object must get its own tracker instance; trackers cannot be shared.
            if not self.multi_tracker.add(self.tracker_method(), image, rect):
                raise NotImplementedError("Failed to add a tracker for box %r" % (rect,))
        ok, bboxes = self.multi_tracker.update(image)
        print("-------- update ---------")
        print(bboxes)
        print("-------------------------")
        if not ok:
            raise NotImplementedError("Initial multiple-target tracker failed !")

    @GetFps
    def trace(self, frame):
        """Advance every tracker by one frame.

        Because of ``GetFps`` the caller receives ``(boxes, fps)``.

        :raises Exception: if OpenCV reports a tracking failure.
        """
        ok, boxes = self.multi_tracker.update(frame)
        if not ok:
            raise Exception("Tracking failure detected !")
        return boxes
def view_result(tracker_method, frame, boxes, fps):
    """Annotate ``frame`` in place with tracked boxes, tracker name and FPS.

    :param tracker_method: tracker name shown as "<name> Tracker".
    :param frame: image drawn on directly.
    :param boxes: sequence of boxes whose items 0..3 are (x, y, w, h).
    :param fps: frames per second; truncated to an int for display.
    :return: the same ``frame`` object.
    """
    for box in boxes:
        left, top, width, height = box[0], box[1], box[2], box[3]
        cv2.rectangle(frame,
                      (int(left), int(top)),
                      (int(left + width), int(top + height)),
                      (200, 0, 0))

    label_font = cv2.FONT_HERSHEY_SIMPLEX
    label_color = (50, 170, 50)
    # Both labels sit at fixed pixel offsets from the top-left corner of the frame.
    cv2.putText(frame, tracker_method + " Tracker", (100, 20), label_font, 0.75, label_color, 2)
    cv2.putText(frame, " FPS : " + str(int(fps)), (100, 50), label_font, 0.75, label_color, 2)
    return frame
if __name__ == "__main__":
    # Script entry point: construct the wrapper and choose the KCF algorithm.
    demo_tracker = MultipleTargetTracker()
    demo_tracker.method_name = "KCF"
调用示例 demo.py：指定一个包含多个待追踪目标的视频，人为标定 4 个 ROI 后即可开始追踪。
# coding:utf-8
import sys
import cv2 as cv
import numpy as np
# Demo: select regions of interest on the first frame of a video, track each one
# with its own KCF tracker, and show the tracked boxes frame by frame.
# Usage: python demo.py [video_path]   (defaults to ../test_data/hiv30.mp4)
NUM_TARGETS = 4
video_path = sys.argv[1] if len(sys.argv) > 1 else '../test_data/hiv30.mp4'

# OpenCV >= 4.5.1 moved MultiTracker and the trackers it accepts into cv.legacy;
# older builds expose them directly on cv. Both must come from the same place.
tracking_api = cv if hasattr(cv, 'MultiTracker_create') else cv.legacy


def build_tracker(frame, boxes):
    """Return a new MultiTracker with one KCF tracker per (x, y, w, h) box on ``frame``."""
    multi_tracker = tracking_api.MultiTracker_create()
    for box in boxes:
        box = tuple(box)
        # Every target needs its own tracker instance.
        if not multi_tracker.add(tracking_api.TrackerKCF_create(), frame, box):
            print('Failed to add tracker for box %s' % (box,))
    return multi_tracker


print('Select %d tracking targets' % NUM_TARGETS)
cv.namedWindow("tracking")
camera = cv.VideoCapture(video_path)
try:
    ok, image = camera.read()
    if not ok:
        print('Failed to read video')
        sys.exit()

    # The ROIs are drawn on this first frame, so the trackers must be initialised
    # on this same frame. selectROI returns a zero-size box when cancelled.
    selected = [cv.selectROI(image, False) for _ in range(NUM_TARGETS)]
    bboxes = [box for box in selected if box[2] > 0 and box[3] > 0]
    if not bboxes:
        print('No valid targets selected')
        sys.exit()
    print('Selected targets: %s' % (bboxes,))

    tracker = build_tracker(image, bboxes)
    last_good_boxes = bboxes
    while camera.isOpened():
        ok, image = camera.read()
        if not ok:
            print('no image to read')
            break
        ok, boxes = tracker.update(image)
        if ok:
            last_good_boxes = boxes
        else:
            # At least one target was lost: restart every tracker on the current
            # frame from the last boxes that were tracked successfully.
            print('tracker 消失')
            tracker = build_tracker(image, last_good_boxes)
        for x, y, w, h in boxes:
            cv.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (200, 0, 0))
        cv.imshow('tracking', image)
        if cv.waitKey(1) == 27:  # Esc pressed
            break
finally:
    # Release the capture device and close all windows, also on early exit.
    camera.release()
    cv.destroyAllWindows()