在制作0-99手写数字模型时,百分类的数据集规模可能偏大,数据集采集训练复杂。可以尝试十分类方案,以下是思路分享:
使用yolo训练目标检测模型
先用目标检测模型对两位数进行分割,检测出其中的每个单个数字。分割之后,后续的分类模型只需识别检测框划出的矩形区域。这样缩小了识别范围,可以提高识别准确率。如下图所示:




训练分类模型
需训练数字0-9及其四个朝向的分类(对称的数字0、1、8不区分朝向,它们四个朝向的图片归到同一个文件夹里),如下图。训练数据可以使用mnist数据集(但mnist数据集存在线条偏粗等缺点,需要先对数据做特殊处理),也可以使用自己采集的手写数据集。

两个单个数字推广到两位数的逻辑
逻辑的关键是确定哪个数字是十位,哪个数字是个位。我设计的逻辑如下:
①如果目标检测只得到一个目标,直接对其进行分类识别,并将识别结果作为最终结果输出。
②如果目标检测得到两个目标,则分别对它们进行分类识别,并记录两个识别结果。如果两个结果中有一个是0,则认为0是个位,另一张图片的结果是十位。如果都不是0,则继续判断:分别取两个检测矩形的中心点坐标(x,y),比较两个中心点在x方向和y方向上的差值,哪个方向的差值大,就按哪个方向(x方向或y方向)判断。Ⅰ如果x的差值大,选择x值较小(更靠左)的目标图片。此图片识别结果如果为2L,3L,4L,5L,6L,7L,9L,则认为这张图片是十位,另一张图片是个位;如果为2R,3R,4R,5R,6R,7R,9R,则认为这张图片是个位,另一张图片是十位。Ⅱ如果y的差值大,选择y值较小(更靠上)的目标图片。此图片识别结果如果为2D,3D,4D,5D,6D,7D,9D,则认为这张图片是十位,另一张图片是个位;如果为2U,3U,4U,5U,6U,7U,9U,则认为这张图片是个位,另一张图片是十位。如果所选图片识别为1或8(没有朝向信息),则改为检查另一张目标图片,此时个位、十位的判定规则与上面相反(例如在x方向上,另一张图片为nL时它是个位,为nR时它是十位)。如果另一张图片也识别为1或8,则默认第一张图片是十位,第二张图片是个位。最后根据两张目标图片的个位、十位结果,输出一个两位数(10*十位+个位)。
③如果目标检测有多于两个目标,则输出100(设置的默认值)
注释1:nL(2L,3L,4L,5L,6L,7L,9L)表示开口向左(left),也是我们最常见的数字朝向。
nU(2U,3U,4U,5U,6U,7U,9U)表示开口向上(up)。
nR(2R,3R,4R,5R,6R,7R,9R)表示开口向右(right)。
nD(2D,3D,4D,5D,6D,7D,9D)表示开口向下(down)。具体如下图:

注释2:在整个画面内,左上角的坐标为(0,0),右下角的坐标为(img.width(),img.height())。
完整代码(MicroPython)
import pyb
import sensor, image, time, tf, gc
import os, math
# Initialise the camera
sensor.reset()
sensor.set_pixformat(sensor.RGB565)
sensor.set_framesize(sensor.QVGA)  # 320x240
sensor.set_brightness(800)
sensor.skip_frames(time=2000)
sensor.set_auto_whitebal(True, (0, 0x80, 0))
clock = time.clock()
# ====== Load the object-detection model ======
detect_model = '/sd/yolo5.tflite'
detect_net = tf.load(detect_model)
# ====== Load the classification model ======
classify_model = "/sd/mobilenet2.tflite"
sensor.set_auto_gain(False)
# One label per line; the line index is used as the classifier output index.
# The file is opened in a `with` block so the handle is closed right after
# reading instead of being left open for the lifetime of the script.
with open("/sd/number.txt") as labels_file:
    classify_labels = [line.rstrip() for line in labels_file]
classify_net = tf.load(classify_model, load_to_fb=True)
# Helper: split a label string into its digit and direction parts
def parse_label(label_str):
    """Split a classifier label such as "2L" into (number, direction).

    Digit characters are collected into the number and every other
    character into the direction suffix, each keeping its original order.

    Args:
        label_str: Label text, e.g. "7U" or "8".

    Returns:
        A tuple (number, direction). number is the int formed by the digit
        characters, or 0 when the label contains none. direction is the
        string of the remaining characters, or "" when there are none.
    """
    num_part = ''.join(ch for ch in label_str if ch.isdigit())
    dir_part = ''.join(ch for ch in label_str if not ch.isdigit())
    return (int(num_part) if num_part else 0), dir_part
# ====== Main loop ======

# Digits whose labels carry no direction suffix in the two-digit logic below,
# so they cannot tell which way the number is oriented.
SYMMETRIC_DIGITS = (1, 8)


def _to_pixel_roi(x1, y1, x2, y2, img_w, img_h, margin=0.01):
    """Convert a normalized detection box into a clamped pixel ROI.

    The box is first grown by `margin` on every side, then scaled to pixels
    and clipped to the image. Both edges are clamped before the size is
    derived, so clipping at the left/top border shrinks the box instead of
    pushing its right/bottom edge outward.

    Returns:
        (x, y, w, h) in pixels; w or h may be <= 0 for a box fully outside
        the image, which the caller filters out.
    """
    x3 = x1 - margin
    y3 = y1 - margin
    x4 = x2 + margin
    y4 = y2 + margin
    # NOTE(review): x is scaled by a fixed 240 while y uses the frame height,
    # although the frame is QVGA (320 wide). Kept as in the original; confirm
    # it matches the detector's coordinate convention.
    left = int(x3 * 240)
    top = int(y3 * img_h)
    right = left + int((x4 - x3) * 240)
    bottom = top + int((y4 - y3) * img_h)
    left = max(0, left)
    top = max(0, top)
    right = min(right, img_w)
    bottom = min(bottom, img_h)
    return left, top, right - left, bottom - top


def _classify_roi(img, roi):
    """Classify the ROI of `img`; return (label, confidence) or None."""
    roi_img = img.copy(roi=roi)  # copy of the ROI region
    best = None
    for obj_class in tf.classify(classify_net, roi_img):
        outputs = obj_class.output()
        max_idx = outputs.index(max(outputs))
        best = (classify_labels[max_idx], outputs[max_idx])
    # Free the ROI copy before the next allocation.
    del roi_img
    gc.collect()
    return best


def _order_digits(first, second, second_tens_dir):
    """Return (ten_digit, unit_digit) for two classified digits.

    `first` is the left (or top) object and `second` the right (or bottom)
    one. `second_tens_dir` is the direction suffix that means the number
    reads starting from `second` ('R' horizontally, 'U' vertically).

    The orientation is read from `first`; when `first` is 1 or 8 it is read
    from `second` instead. If no usable direction is found, `first` is
    taken as the tens digit.
    """
    probe = second if first['class_num'] in SYMMETRIC_DIGITS else first
    if second_tens_dir in probe['class_dir']:
        return second['class_num'], first['class_num']
    return first['class_num'], second['class_num']


while True:
    clock.tick()
    img = sensor.snapshot()
    img_w = img.width()
    img_h = img.height()
    detected_objects = []  # detection results, also used for terminal output
    final_number = None  # final number to output

    # Stage 1: object detection. Stage 2: classification inside each box.
    # Nothing is drawn on the frame yet, so the classifier never sees the
    # rectangle / cross / text overlays.
    for obj in tf.detect(detect_net, img):
        x1, y1, x2, y2, label, score = obj
        if score <= 0.70:  # confidence threshold
            continue
        x, y, w, h = _to_pixel_roi(x1, y1, x2, y2, img_w, img_h)
        if w <= 5 or h <= 5:  # skip invalid regions
            continue
        result = _classify_roi(img, (x, y, w, h))
        if result is None:
            # No classification output: skip rather than store an entry
            # without the class_* keys that the code below reads.
            continue
        label_str, confidence = result
        class_num, class_dir = parse_label(label_str)
        detected_objects.append({
            'position': (x, y, w, h),
            'center': (x + w // 2, y + h // 2),
            'detect_score': score,
            'class_str': label_str,
            'class_num': class_num,
            'class_dir': class_dir,
            'class_confidence': confidence*100,
        })

    # Draw the overlays only after every ROI has been classified.
    for info in detected_objects:
        x, y, w, h = info['position']
        center_x, center_y = info['center']
        img.draw_rectangle((x, y, w, h), color=(0, 255, 0), thickness=2)
        img.draw_cross(center_x, center_y, color=(0, 255, 255), size=5)
        text = "{}:{:.1f}%".format(info['class_str'], info['class_confidence'])
        img.draw_string(x, y-10, text, color=(255, 0, 0), scale=1.5)

    # Combine the digits according to the number of objects
    num_objects = len(detected_objects)
    if num_objects == 1:
        # A single object is output directly
        final_number = detected_objects[0]['class_num']
        print("Single object detected, output:", final_number)
    elif num_objects == 2:
        obj1, obj2 = detected_objects[0], detected_objects[1]
        num1, num2 = obj1['class_num'], obj2['class_num']
        dir1, dir2 = obj1['class_dir'], obj2['class_dir']
        print(f"Two objects: {num1}{dir1} and {num2}{dir2}")
        # A zero is always treated as the units digit
        if num1 == 0 or num2 == 0:
            ten_digit = num1 if num1 != 0 else num2
            unit_digit = 0
            final_number = ten_digit * 10 + unit_digit
            print("One object is 0, output:", final_number)
        else:
            # Distance between the two centres along each axis
            dx = abs(obj1['center'][0] - obj2['center'][0])
            dy = abs(obj1['center'][1] - obj2['center'][1])
            print(f"Center differences - dx: {dx}, dy: {dy}")
            if dx > dy:  # the digits are laid out horizontally
                # Object with the smaller x (further left) comes first
                if obj1['center'][0] < obj2['center'][0]:
                    left_obj, right_obj = obj1, obj2
                else:
                    left_obj, right_obj = obj2, obj1
                print(f"Horizontal priority - Left object: {left_obj['class_str']}")
                if left_obj['class_num'] in SYMMETRIC_DIGITS:
                    print("Left is 1/8, using right object's direction")
                # 'R' means the number reads from the right object
                ten_digit, unit_digit = _order_digits(left_obj, right_obj, 'R')
            else:  # the digits are laid out vertically
                # Object with the smaller y (further up) comes first
                if obj1['center'][1] < obj2['center'][1]:
                    top_obj, bottom_obj = obj1, obj2
                else:
                    top_obj, bottom_obj = obj2, obj1
                print(f"Vertical priority - Top object: {top_obj['class_str']}")
                if top_obj['class_num'] in SYMMETRIC_DIGITS:
                    print("Top is 1/8, using bottom object's direction")
                # 'U' means the number reads from the bottom object
                ten_digit, unit_digit = _order_digits(top_obj, bottom_obj, 'U')
            final_number = ten_digit * 10 + unit_digit
            print(f"Two objects combined: {final_number}")
    elif num_objects > 2:
        # More than two objects: output the sentinel value 100
        final_number = 100
        print(f"Multiple objects ({num_objects}), output: 100")

    # Terminal report
    print("\n------ Frame Info ------")
    print("FPS: {:.1f}".format(clock.fps()))
    print(f"Detected objects: {num_objects}")
    if detected_objects:
        print("\nDetected Objects:")
        for i, obj in enumerate(detected_objects, 1):
            x, y, w, h = obj['position']
            cx, cy = obj['center']
            print(f"Object {i}:")
            print(f" Position: ({x}, {y}) Size: {w}x{h}")
            print(f" Center: ({cx}, {cy})")
            print(f" Detect Score: {obj['detect_score']:.2f}")
            print(f" Class: {obj['class_str']} (Num: {obj['class_num']}, Dir: {obj['class_dir']})")
            print(f" Class Confidence: {obj['class_confidence']:.1f}%")
    else:
        print("No objects detected")
    if final_number is not None:
        print(f"Final Output Number: {final_number}")
    print("-----------------------")

    # Show the frame rate and the final result on the image
    fps_text = "FPS:{:.1f}".format(clock.fps())
    img.draw_string(5, 5, fps_text, color=(0, 0, 255), scale=2.0)
    if final_number is not None:
        result_text = "Num:{}".format(final_number)
        img.draw_string(img.width()//2-30, img.height()-30, result_text, color=(255, 0, 0), scale=2.0)
&spm=1001.2101.3001.5002&articleId=150402921&d=1&t=3&u=c74dce77172540ad9fc1cc8f18ff5b39)
2118

被折叠的 条评论
为什么被折叠?



