import serial
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
from matplotlib.patches import Rectangle, Circle
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"]
mpl.rcParams["axes.unicode_minus"] = False
class EggCounter:
def __init__(self):
self.grid_size = (8, 8)
self.cover_width = 100 # 宽度100mm
self.cell_size = self.cover_width / 8
self.background = None
self.bg_alpha = 0.1
self.next_id = 0
self.tracked_eggs = [] # 跟踪的鸡蛋对象
self.max_missed = 5
self.depth_threshold = 15
self.min_points = 3
self.exit_threshold = 30 # 右侧出口触发位置(mm),传送方向从左到右
self.total_count = 0
self.speed_estimation = [] # 用于速度估计的历史位置
def _grid_to_physical(self, i, j):
"""将网格坐标转换为物理坐标(传送方向从左到右)"""
x = (j - 3.5) * self.cell_size # X: 传送方向(从左到右)
y = (i - 3.5) * self.cell_size # Y: 宽度方向(垂直于传送)
return x, y
def _cluster_points(self, mask):
"""对前景点进行聚类并返回中心坐标"""
clusters = []
visited = np.zeros_like(mask, dtype=bool)
def dfs(i, j, points):
stack = [(i, j)]
while stack:
x, y = stack.pop()
if visited[x, y] or not mask[x, y]:
continue
visited[x, y] = True
points.append((x, y))
for dx in (-1, 0, 1):
for dy in (-1, 0, 1):
nx, ny = x+dx, y+dy
if 0 <= nx < 8 and 0 <= ny < 8:
stack.append((nx, ny))
return points
for i in range(8):
for j in range(8):
if mask[i, j] and not visited[i, j]:
cluster = dfs(i, j, [])
if len(cluster) >= self.min_points:
centers = [self._grid_to_physical(*p) for p in cluster]
cx = np.mean([c[0] for c in centers])
cy = np.mean([c[1] for c in centers])
clusters.append((cx, cy))
return clusters
def _update_background(self, frame):
"""动态更新背景模型"""
if self.background is None:
self.background = frame.copy()
else:
foreground = (self.background - frame) > self.depth_threshold
self.background[~foreground] = (
(1 - self.bg_alpha) * self.background[~foreground]
+ self.bg_alpha * frame[~foreground]
)
def _match_and_track(self, current_eggs):
"""匹配当前帧检测到的鸡蛋与跟踪目标"""
# 如果没有跟踪目标,全部初始化为新目标
if not self.tracked_eggs:
for pos in current_eggs:
self.tracked_eggs.append({
'id': self.next_id,
'pos': pos,
'history': [pos], # 位置历史用于速度估计
'counted': False,
'missed': 0
})
self.next_id += 1
return
# 计算目标间的距离成本矩阵
cost_matrix = np.zeros((len(self.tracked_eggs), len(current_eggs)))
for i, egg in enumerate(self.tracked_eggs):
for j, pos in enumerate(current_eggs):
# 使用欧氏距离作为匹配成本
dist = np.sqrt((egg['pos'][0]-pos[0])**2 + (egg['pos'][1]-pos[1])**2)
cost_matrix[i, j] = dist
# 使用匈牙利算法进行匹配
row_idx, col_idx = linear_sum_assignment(cost_matrix)
# 更新匹配的目标
matched_targets = set()
matched_detections = set()
for i, j in zip(row_idx, col_idx):
if cost_matrix[i, j] < 20: # 距离阈值20mm
self.tracked_eggs[i]['pos'] = current_eggs[j]
self.tracked_eggs[i]['history'].append(current_eggs[j])
if len(self.tracked_eggs[i]['history']) > 5:
self.tracked_eggs[i]['history'].pop(0)
self.tracked_eggs[i]['missed'] = 0
matched_targets.add(i)
matched_detections.add(j)
# 处理未匹配的目标(丢失)
for i, egg in enumerate(self.tracked_eggs):
if i not in matched_targets:
egg['missed'] += 1
# 添加新的检测目标
for j, pos in enumerate(current_eggs):
if j not in matched_detections:
self.tracked_eggs.append({
'id': self.next_id,
'pos': pos,
'history': [pos],
'counted': False,
'missed': 0
})
self.next_id += 1
def process_frame(self, frame):
"""
处理一帧深度数据并返回检测到的鸡蛋位置
:param frame: 8x8深度矩阵 (单位:mm)
:return: (当前检测到的鸡蛋位置, 总计数)
"""
# 1. 更新背景模型
self._update_background(frame)
# 2. 前景检测 (鸡蛋深度小于背景)
foreground = (self.background - frame) > self.depth_threshold
# 3. 聚类检测到的鸡蛋
current_eggs = self._cluster_points(foreground)
# 4. 目标跟踪与匹配
self._match_and_track(current_eggs)
# 5. 计数逻辑(鸡蛋从右侧离开时计数)
for egg in self.tracked_eggs.copy():
# 如果鸡蛋在出口区域且未计数
if not egg['counted'] and egg['pos'][0] >= self.exit_threshold:
self.total_count += 1
egg['counted'] = True
# 标记为已计数但继续跟踪直到离开视野
# 移出检测区域(右侧太远或宽度方向超出范围)
if abs(egg['pos'][1]) > 50 or egg['pos'][0] > 60:
self.tracked_eggs.remove(egg)
return [egg['pos'] for egg in self.tracked_eggs if not egg['counted']], self.total_count
def init_serial(port='COM12', baudrate=115200):
"""初始化串口连接"""
return serial.Serial(port, baudrate, timeout=10)
def create_blackwhite_colormap():
"""创建黑白渐变色图"""
colors = [(0.95, 0.95, 0.95), (0, 0, 0)] # 浅灰到黑
return plt.cm.colors.LinearSegmentedColormap.from_list('bw', colors, N=256)
def init_plot():
"""初始化深度图显示(传送方向从左到右)"""
fig = plt.figure(figsize=(14, 7))
fig.suptitle('VL53L5CX 鸡蛋计数系统(传送方向:左→右)', fontsize=16)
# 深度图区域
ax1 = fig.add_subplot(121)
cmap = create_blackwhite_colormap()
img = ax1.imshow(np.zeros((8,8)), cmap=cmap, vmin=0, vmax=500)
plt.colorbar(img, label='距离(mm)', ax=ax1)
ax1.set_title('8x8 深度图')
ax1.set_xlabel('传送方向(左→右)')
ax1.set_ylabel('宽度方向')
# 添加传送带表示(红色表示出口)
ax1.add_patch(Rectangle([-0.5, -0.5], 8, 8, fill=False, edgecolor='red', linestyle='--'))
ax1.text(7.5, 3.5, '→', color='red', fontsize=20, ha='center') # 传送方向箭头
# 鸡蛋位置可视化区域
ax2 = fig.add_subplot(122)
ax2.set_xlim(-60, 60)
ax2.set_ylim(-60, 60)
ax2.set_title('鸡蛋位置检测')
ax2.set_xlabel('传送方向(mm) (左→右)')
ax2.set_ylabel('宽度方向(mm)')
# 添加检测区域和计数线
ax2.add_patch(Rectangle([-50, -50], 100, 100, fill=False, edgecolor='blue'))
ax2.axvline(x=30, color='r', linestyle='--', alpha=0.7) # 出口计数线
ax2.text(32, -45, '计数线', color='r', rotation=90)
# 添加传送方向箭头
ax2.arrow(-40, 50, 70, 0, head_width=5, head_length=5, fc='r', ec='r')
ax2.text(0, 55, '传送方向', color='r', ha='center')
# 鸡蛋计数显示
count_text = ax2.text(0, 45, '鸡蛋计数: 0',
ha='center', va='center',
fontsize=24, color='green',
bbox=dict(facecolor='white', alpha=0.8))
# 帧率显示
fps_text = ax2.text(0, -45, 'FPS: 0.0',
ha='center', va='center',
fontsize=12, color='blue')
# 当前跟踪目标显示
track_text = ax2.text(40, -45, '跟踪目标: 0',
ha='left', va='center',
fontsize=12, color='purple')
# 鸡蛋位置点容器
egg_points = ax2.scatter([], [], s=100, c='red', edgecolors='black', alpha=0.7)
return fig, img, ax2, count_text, fps_text, track_text, egg_points
def update_frame(frame, ser, img, ax, count_text, fps_text, track_text, egg_points, counter, last_time):
"""更新深度图帧并处理鸡蛋计数(传送方向从左到右)"""
current_time = time.time()
fps = 1 / (current_time - last_time[0]) if current_time != last_time[0] else 0
last_time[0] = current_time
try:
# 读取串口数据
data = ser.readline().decode('utf-8').strip()
if data:# and data.startswith('[') and data.endswith(']'):
# 处理数据格式: [123.4, 125.6, ...]
#data = data[1:-1] # 移除方括号
distances = list(map(float, data.split(',')))
if len(distances) == 64:
depth_map = np.array(distances).reshape(8,8)
# 更新深度图显示
img.set_array(depth_map)
img.autoscale()
# 鸡蛋检测与计数
eggs, total_count = counter.process_frame(depth_map)
# 更新鸡蛋位置散点图
if eggs:
x_pos = [egg[0] for egg in eggs]
y_pos = [egg[1] for egg in eggs]
egg_points.set_offsets(np.column_stack([x_pos, y_pos]))
else:
egg_points.set_offsets([])
# 更新计数显示
count_text.set_text(f'鸡蛋计数: {total_count}')
fps_text.set_text(f'FPS: {fps:.1f}')
track_text.set_text(f'跟踪目标: {len(counter.tracked_eggs)}')
except Exception as e:
print(f"Error: {e}")
return img, egg_points, count_text, fps_text, track_text
def main():
ser = init_serial()
counter = EggCounter()
fig, img, ax2, count_text, fps_text, track_text, egg_points = init_plot()
last_time = [time.time()]
ani = animation.FuncAnimation(
fig, update_frame,
fargs=(ser, img, ax2, count_text, fps_text, track_text, egg_points, counter, last_time),
interval=50, blit=True
)
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()
ser.close()
if __name__ == "__main__":
from scipy.optimize import linear_sum_assignment # 确保导入
main()