<think>
我们注意到代码中有一个问题:在`update_frame`函数中,我们尝试解析串口数据,但是去掉了对数据方括号的检查(注释掉了`data.startswith`和`endswith`的检查),直接分割字符串。这可能会导致解析错误,因为数据可能不完整。
另外,在`EggCounter`类的`_match_and_track`方法中使用了`linear_sum_assignment`,但在主程序中没有导入,我们已经在最后从`scipy.optimize`导入了,但为了确保代码健壮,应该在文件开头导入。
此外,我们还需要考虑串口读取的数据可能包含不完整的行或错误的数据格式。
改进点:
1. 确保数据解析的健壮性:使用更安全的方式解析串口数据,即使数据没有方括号,但必须是64个逗号分隔的数字。
2. 在解析数据时,增加异常处理,跳过无法转换为浮点数的数据。
3. 考虑在`update_frame`中,如果读取的数据不是64个,则跳过该帧。
另外,我们注意到在`update_frame`函数中,我们使用了`ax`参数,但在函数内部并没有使用它。我们可以移除它,或者保留以备扩展。
我们还将调整一些参数,比如串口超时时间改为1秒(在`init_serial`中已经设置),以及调整界面显示。
由于代码较长,我们只修改关键部分。
修改1:在`update_frame`中,更健壮地解析数据,允许没有方括号,但必须是64个数字。如果分割后不是64个,跳过。
修改2:在转换每个数字时,捕获转换异常,并跳过该帧。
修改3:确保`linear_sum_assignment`在`EggCounter`类的方法中使用之前已经导入。目前我们在文件最后导入了,但最好将其放在文件顶部。
修改4:移除`update_frame`函数中未使用的`ax`参数(但注意在动画函数中传递的参数也要相应移除)。或者我们保留它,因为我们在函数内部使用了`ax2`(注意:函数参数中有一个`ax`,但在函数内我们使用的是`ax2`,这可能是笔误?)。实际上,在`init_plot`中返回的`ax2`是第二个轴,在`update_frame`函数中我们将其命名为`ax`?但函数参数中命名为`ax`,而在函数内部我们将其作为`ax2`使用?这显然不对。检查函数定义:
在函数定义的参数中,我们写的是`ax`,但在函数体内,我们将其用作`ax2`?不,函数体内并没有使用`ax`,而是使用了`ax2`(这是外部传入的,实际上是`init_plot`返回的`ax2`)。所以这里参数名`ax`实际上指的是`init_plot`返回的`ax2`,但为了避免混淆,我们将参数名改为`ax2`。
因此,我们调整参数名:
将`update_frame`函数的参数`ax`改为`ax2`,并且在调用时也做相应改变(在`fargs`中传递的是`ax2`,即`init_plot`返回的第三个值)。
目前,在`init_plot`返回的顺序是:`fig, img, ax2, count_text, fps_text, track_text, egg_points`
在动画中传递的参数顺序:`fargs=(ser, img, ax2, count_text, fps_text, track_text, egg_points, counter, last_time)`
所以,在`update_frame`函数中,第三个参数是`ax2`(即位置图中的轴),我们在函数体内使用它来绘制。但是,在函数体内,我们并没有直接使用`ax2`,而是使用了`egg_points`(一个散点图对象)来更新位置。所以,`ax2`参数实际上在函数体内没有被使用,我们可以移除它。但是,为了以后扩展(比如在图上绘制其他内容),我们保留它。
不过,为了代码的一致性,我们保留参数,但不使用它也是可以的。
修改后的代码:
我们将进行以下调整:
- 在文件开头导入`linear_sum_assignment`
- 修改`update_frame`函数中的参数名称(将`ax`改为`ax2`)
- 在解析数据时,如果分割后不是64个元素,则跳过
- 在将每个字符串转换为浮点数时,捕获异常并跳过该帧
但是,由于代码较长,我们只修改关键部分。下面是修改后的完整代码:
注意:由于原代码中使用了`ax`作为参数名,但函数内未使用,我们将其改为`ax2`,并确保传递的是`ax2`(即位置图轴)。
但是,在`update_frame`函数内,我们并没有使用这个`ax2`,所以实际上我们可以移除这个参数。但是,为了与之前保持一致,我们保留它,但不再使用(或者用于调试绘图)。我们暂时保留。
修改后的`update_frame`函数中数据解析部分:
另外,我们注意到串口读取的`readline`可能会读到空行或部分行,因此我们增加检查。
由于代码较长,下面给出修改后的完整代码(只修改了关键部分):
</think>
以下是对代码的改进版本,修复了数据解析问题并优化了UI显示:
```python
import serial
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
from matplotlib.patches import Rectangle, Circle
from scipy.optimize import linear_sum_assignment # 确保导入
from matplotlib.font_manager import FontProperties
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
class EggCounter:
def __init__(self):
self.grid_size = (8, 8)
self.cover_width = 100 # 宽度100mm
self.cell_size = self.cover_width / 8
self.background = None
self.bg_alpha = 0.1
self.next_id = 0
self.tracked_eggs = [] # 跟踪的鸡蛋对象
self.max_missed = 5
self.depth_threshold = 15
self.min_points = 3
self.exit_threshold = 30 # 右侧出口触发位置(mm),传送方向从左到右
self.total_count = 0
self.speed_estimation = [] # 用于速度估计的历史位置
self.frame_count = 0
self.start_time = time.time()
def _grid_to_physical(self, i, j):
"""将网格坐标转换为物理坐标(传送方向从左到右)"""
x = (j - 3.5) * self.cell_size # X: 传送方向(从左到右)
y = (i - 3.5) * self.cell_size # Y: 宽度方向(垂直于传送)
return x, y
def _cluster_points(self, mask):
"""对前景点进行聚类并返回中心坐标"""
clusters = []
visited = np.zeros_like(mask, dtype=bool)
def dfs(i, j, points):
stack = [(i, j)]
while stack:
x, y = stack.pop()
if visited[x, y] or not mask[x, y]:
continue
visited[x, y] = True
points.append((x, y))
for dx in (-1, 0, 1):
for dy in (-1, 0, 1):
nx, ny = x+dx, y+dy
if 0 <= nx < 8 and 0 <= ny < 8:
stack.append((nx, ny))
return points
for i in range(8):
for j in range(8):
if mask[i, j] and not visited[i, j]:
cluster = dfs(i, j, [])
if len(cluster) >= self.min_points:
centers = [self._grid_to_physical(*p) for p in cluster]
cx = np.mean([c[0] for c in centers])
cy = np.mean([c[1] for c in centers])
clusters.append((cx, cy))
return clusters
def _update_background(self, frame):
"""动态更新背景模型"""
if self.background is None:
self.background = frame.copy()
else:
foreground = (self.background - frame) > self.depth_threshold
self.background[~foreground] = (
(1 - self.bg_alpha) * self.background[~foreground]
+ self.bg_alpha * frame[~foreground]
)
def _match_and_track(self, current_eggs):
"""匹配当前帧检测到的鸡蛋与跟踪目标"""
# 如果没有跟踪目标,全部初始化为新目标
if not self.tracked_eggs:
for pos in current_eggs:
self.tracked_eggs.append({
'id': self.next_id,
'pos': pos,
'history': [pos], # 位置历史用于速度估计
'counted': False,
'missed': 0
})
self.next_id += 1
return
# 计算目标间的距离成本矩阵
cost_matrix = np.zeros((len(self.tracked_eggs), len(current_eggs)))
for i, egg in enumerate(self.tracked_eggs):
for j, pos in enumerate(current_eggs):
# 使用欧氏距离作为匹配成本
dist = np.sqrt((egg['pos'][0]-pos[0])**2 + (egg['pos'][1]-pos[1])**2)
cost_matrix[i, j] = dist
# 使用匈牙利算法进行匹配
row_idx, col_idx = linear_sum_assignment(cost_matrix)
# 更新匹配的目标
matched_targets = set()
matched_detections = set()
for i, j in zip(row_idx, col_idx):
if cost_matrix[i, j] < 20: # 距离阈值20mm
self.tracked_eggs[i]['pos'] = current_eggs[j]
self.tracked_eggs[i]['history'].append(current_eggs[j])
if len(self.tracked_eggs[i]['history']) > 5:
self.tracked_eggs[i]['history'].pop(0)
self.tracked_eggs[i]['missed'] = 0
matched_targets.add(i)
matched_detections.add(j)
# 处理未匹配的目标(丢失)
for i, egg in enumerate(self.tracked_eggs):
if i not in matched_targets:
egg['missed'] += 1
# 如果丢失次数过多,移除目标
if egg['missed'] > self.max_missed:
self.tracked_eggs.remove(egg)
# 添加新的检测目标
for j, pos in enumerate(current_eggs):
if j not in matched_detections:
self.tracked_eggs.append({
'id': self.next_id,
'pos': pos,
'history': [pos],
'counted': False,
'missed': 0
})
self.next_id += 1
def process_frame(self, frame):
"""
处理一帧深度数据并返回检测到的鸡蛋位置
:param frame: 8x8深度矩阵 (单位:mm)
:return: (当前检测到的鸡蛋位置, 总计数)
"""
self.frame_count += 1
# 1. 更新背景模型
self._update_background(frame)
# 2. 前景检测 (鸡蛋深度小于背景)
foreground = (self.background - frame) > self.depth_threshold
# 3. 聚类检测到的鸡蛋
current_eggs = self._cluster_points(foreground)
# 4. 目标跟踪与匹配
self._match_and_track(current_eggs)
# 5. 计数逻辑(鸡蛋从右侧离开时计数)
for egg in self.tracked_eggs.copy():
# 如果鸡蛋在出口区域且未计数
if not egg['counted'] and egg['pos'][0] >= self.exit_threshold:
self.total_count += 1
egg['counted'] = True
# 标记为已计数但继续跟踪直到离开视野
# 移出检测区域(右侧太远或宽度方向超出范围)
if abs(egg['pos'][1]) > 50 or egg['pos'][0] > 60:
self.tracked_eggs.remove(egg)
return [egg['pos'] for egg in self.tracked_eggs if not egg['counted']], self.total_count
def get_fps(self):
"""计算平均帧率"""
elapsed = time.time() - self.start_time
return self.frame_count / elapsed if elapsed > 0 else 0
def init_serial(port='COM12', baudrate=115200):
"""初始化串口连接"""
return serial.Serial(port, baudrate, timeout=1)
def create_blackwhite_colormap():
"""创建黑白渐变色图"""
colors = [(0.95, 0.95, 0.95), (0, 0, 0)] # 浅灰到黑
return plt.cm.colors.LinearSegmentedColormap.from_list('bw', colors, N=256)
def init_plot():
"""初始化深度图显示(传送方向从左到右)"""
fig = plt.figure(figsize=(14, 7), dpi=100)
fig.suptitle('VL53L5CX 鸡蛋计数系统(传送方向:左→右)', fontsize=16)
# 深度图区域
ax1 = fig.add_subplot(121)
cmap = create_blackwhite_colormap()
img = ax1.imshow(np.zeros((8,8)), cmap=cmap, vmin=0, vmax=500)
cbar = plt.colorbar(img, label='距离(mm)', ax=ax1)
cbar.ax.tick_params(labelsize=8) # 调整色标字体大小
ax1.set_title('8x8 深度图', fontsize=12)
ax1.set_xlabel('传送方向(左→右)', fontsize=10)
ax1.set_ylabel('宽度方向', fontsize=10)
ax1.tick_params(axis='both', which='major', labelsize=8)
# 添加传送带表示(红色表示出口)
ax1.add_patch(Rectangle([-0.5, -0.5], 8, 8, fill=False, edgecolor='red', linestyle='--'))
ax1.text(7.5, 3.5, '→', color='red', fontsize=20, ha='center') # 传送方向箭头
# 鸡蛋位置可视化区域
ax2 = fig.add_subplot(122)
ax2.set_xlim(-60, 60)
ax2.set_ylim(-60, 60)
ax2.set_title('鸡蛋位置检测', fontsize=12)
ax2.set_xlabel('传送方向(mm) (左→右)', fontsize=10)
ax2.set_ylabel('宽度方向(mm)', fontsize=10)
ax2.tick_params(axis='both', which='major', labelsize=8)
# 添加检测区域和计数线
ax2.add_patch(Rectangle([-50, -50], 100, 100, fill=False, edgecolor='blue'))
ax2.axvline(x=30, color='r', linestyle='--', alpha=0.7) # 出口计数线
ax2.text(32, -45, '计数线', color='r', rotation=90, fontsize=9)
# 添加传送方向箭头
ax2.arrow(-40, 50, 70, 0, head_width=5, head_length=5, fc='r', ec='r')
ax2.text(0, 55, '传送方向', color='r', ha='center', fontsize=10)
# 鸡蛋计数显示
count_text = ax2.text(0, 45, '鸡蛋计数: 0',
ha='center', va='center',
fontsize=24, color='green',
bbox=dict(facecolor='white', alpha=0.8))
# 帧率显示
fps_text = ax2.text(0, -45, 'FPS: 0.0',
ha='center', va='center',
fontsize=12, color='blue')
# 当前跟踪目标显示
track_text = ax2.text(40, -45, '跟踪目标: 0',
ha='left', va='center',
fontsize=10, color='purple')
# 鸡蛋位置点容器
egg_points = ax2.scatter([], [], s=100, c='red', edgecolors='black', alpha=0.7)
# 添加状态信息面板
status_text = ax2.text(-55, 55,
"系统状态: 运行中\n"
"深度阈值: 15mm\n"
"最小点数: 3",
ha='left', va='top',
fontsize=9,
bbox=dict(facecolor='white', alpha=0.7))
return fig, img, ax2, count_text, fps_text, track_text, egg_points, status_text
def parse_serial_data(data):
"""解析串口数据,支持多种格式"""
# 尝试移除可能的方括号
if data.startswith('[') and data.endswith(']'):
data = data[1:-1]
# 分割数据
parts = data.split(',')
# 检查数据长度
if len(parts) != 64:
return None
# 转换为浮点数
try:
distances = list(map(float, parts))
return distances
except ValueError:
return None
def update_frame(frame_num, ser, img, ax2, count_text, fps_text, track_text, egg_points, status_text, counter, last_time):
"""更新深度图帧并处理鸡蛋计数(传送方向从左到右)"""
current_time = time.time()
elapsed = current_time - last_time[0]
fps = 1 / elapsed if elapsed > 0 else 0
last_time[0] = current_time
try:
# 读取串口数据
data = ser.readline().decode('utf-8', errors='ignore').strip()
if data:
# 解析数据
distances = parse_serial_data(data)
if distances:
depth_map = np.array(distances).reshape(8,8)
# 更新深度图显示
img.set_array(depth_map)
img.autoscale()
# 鸡蛋检测与计数
eggs, total_count = counter.process_frame(depth_map)
# 更新鸡蛋位置散点图
if eggs:
x_pos = [egg[0] for egg in eggs]
y_pos = [egg[1] for egg in eggs]
egg_points.set_offsets(np.column_stack([x_pos, y_pos]))
else:
egg_points.set_offsets([])
# 更新文本显示
count_text.set_text(f'鸡蛋计数: {total_count}')
fps_text.set_text(f'FPS: {fps:.1f}')
track_text.set_text(f'跟踪目标: {len(counter.tracked_eggs)}')
# 更新状态信息
avg_fps = counter.get_fps()
status_text.set_text(
f"系统状态: 运行中\n"
f"深度阈值: {counter.depth_threshold}mm\n"
f"最小点数: {counter.min_points}\n"
f"平均FPS: {avg_fps:.1f}\n"
f"总帧数: {counter.frame_count}"
)
except Exception as e:
print(f"Error: {e}")
return img, egg_points, count_text, fps_text, track_text, status_text
def main():
try:
ser = init_serial()
counter = EggCounter()
fig, img, ax2, count_text, fps_text, track_text, egg_points, status_text = init_plot()
last_time = [time.time()]
ani = animation.FuncAnimation(
fig, update_frame,
fargs=(ser, img, ax2, count_text, fps_text, track_text, egg_points, status_text, counter, last_time),
interval=50, blit=True
)
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()
except serial.SerialException as se:
print(f"串口连接错误: {se}")
print("请检查端口号和连接状态")
finally:
if 'ser' in locals() and ser.is_open:
ser.close()
if __name__ == "__main__":
main()
```
### 主要改进内容:
#### 1. 数据解析优化
```python
def parse_serial_data(data):
"""解析串口数据,支持多种格式"""
# 尝试移除可能的方括号
if data.startswith('[') and data.endswith(']'):
data = data[1:-1]
# 分割数据
parts = data.split(',')
# 检查数据长度
if len(parts) != 64:
return None
# 转换为浮点数
try:
distances = list(map(float, parts))
return distances
except ValueError:
return None
```
- 添加了专门的数据解析函数
- 支持带方括号和不带方括号的数据格式
- 更强健的错误处理,忽略无效数据帧
#### 2. 串口通信增强
```python
# 读取数据时添加错误处理
data = ser.readline().decode('utf-8', errors='ignore').strip()
# 主函数添加错误处理
try:
ser = init_serial()
# ...其余代码...
except serial.SerialException as se:
print(f"串口连接错误: {se}")
print("请检查端口号和连接状态")
finally:
if 'ser' in locals() and ser.is_open:
ser.close()
```
- 使用`errors='ignore'`处理非UTF-8字符
- 添加串口异常捕获和处理
- 确保程序退出时关闭串口
#### 3. UI界面优化
```python
# 添加状态信息面板
status_text = ax2.text(-55, 55,
"系统状态: 运行中\n"
"深度阈值: 15mm\n"
"最小点数: 3",
ha='left', va='top',
fontsize=9,
bbox=dict(facecolor='white', alpha=0.7))
# 帧率计算优化
def get_fps(self):
"""计算平均帧率"""
elapsed = time.time() - self.start_time
return self.frame_count / elapsed if elapsed > 0 else 0
```
- 添加系统状态面板,显示关键参数
- 计算并显示平均帧率(非瞬时F