手写数字0-99的识别思路分享(10分类方案)

部署运行你感兴趣的模型镜像

在制作0-99手写数字模型时,百分类的数据集规模可能偏大,数据集采集训练复杂。可以尝试十分类方案,以下是思路分享:

使用yolo训练目标检测模型

        对两位数字进行分割,检测分出单个数字。进行划分后,之后的目标分类可以识别划分出的矩形区域。这样缩小识别范围,可以提高识别准确率。如下图所示:

训练分类模型

需训练数字0-9及其四个方向(存在对称的数字(0,1,8)不进行方向分类,四个方向的数字归到一个文件夹里),如下图。训练的数据集可以使用mnist数据集(但mnist数据集存在线条偏粗等缺点,需要对数据做特殊处理),或者使用自己采集的手写数据集。

两个单个数字推广到两位数的逻辑

逻辑的关键是确定哪个数字是十位,哪个数字是个位。我设计的逻辑如下:

    ①目标检测如果得到一个目标,直接进行目标分类识别,将识别到的结果即为最终结果输出。

    ②如果目标检测得到两个目标输出,进行目标分类识别,分别记录两个识别结果。如果两个结果中有一个结果是0,则认为0是个位,另一张图片的结果是十位。如果没有0接着进行,分别得到两个目标检测后矩形的中心点坐标(x,y),比较两中心点坐标的x和y之间差值,谁大,选择哪个方向(x方向或y方向)。Ⅰ如果x的差值大,选择x的值小的目标图片。此图片识别结果如果为2L,3L,4L,5L,6L,7L,9L,则认为这张目标检测的结果是数字十位,另一张图片是数字个位。此图片识别结果如果为2R,3R,4R,5R,6R,7R,9R,则认为这张目标检测的结果是数字个位,另一张图片是数字十位。Ⅱ如果y的差值大,选择y的值小的目标图片。此图片识别结果如果为2D,3D,4D,5D,6D,7D,9D,则认为这张目标检测的结果是数字十位,另一张图片是数字个位。此图片识别结果如果为2U,3U,4U,5U,6U,7U,9U,则认为这张目标检测的结果是数字个位,另一张图片是数字十位。如果识别到为1,8,再检测另一张目标图片,得出个,十位结果规则与第一次的相反。第二张图片如果再识别出1,8,默认第一个图是十位,第二张图结果为个位。 根据得到的两张目标图片的个,十位结果,得到一个两位数输出(10*十位+个位)

    ③如果目标检测有多于两个目标,则输出100(设置的默认值)

    注释1:nL(2L,3L,4L,5L,6L,7L,9L)表示开口向左(left),也是我们最常见的数字朝向。

            nU(2U,3U,4U,5U,6U,7U,9U)表示开口向上(up)。

            nR(2R,3R,4R,5R,6R,7R,9R)表示开口向右(right)。

            nD(2D,3D,4D,5D,6D,7D,9D)表示开口向下(down)。具体如下图:

    注释2:在整个画面内,左上角的坐标为(0,0),右下角的坐标为(img.width(),img.height())。

    完整代码(MicroPython)

    import pyb
    import sensor, image, time, tf, gc
    import os, math
    
    # 初始化摄像头
    sensor.reset()
    sensor.set_pixformat(sensor.RGB565)
    sensor.set_framesize(sensor.QVGA)  # 320x240
    sensor.set_brightness(800)
    sensor.skip_frames(time=2000)
    sensor.set_auto_whitebal(True, (0, 0x80, 0))
    clock = time.clock()
    
    # ====== 加载目标检测模型 ======
    detect_model = '/sd/yolo5.tflite'
    detect_net = tf.load(detect_model)
    
    # ====== 加载分类模型 ====== mobilenet-
    classify_model = "/sd/mobilenet2.tflite"
    sensor.set_auto_gain(False)
    classify_labels = [line.rstrip() for line in open("/sd/number.txt")]
    classify_net = tf.load(classify_model, load_to_fb=True)
    
    # 辅助函数:从标签字符串中提取数字和方向
    def parse_label(label_str):
        num_part = ''
        dir_part = ''
        for char in label_str:
            if char.isdigit():
                num_part += char
            else:
                dir_part += char
        return int(num_part) if num_part else 0, dir_part
    
    # ====== 主循环 ======
    while True:
        clock.tick()
        img = sensor.snapshot()
        detected_objects = []  # 存储检测结果用于终端输出
        final_number = None    # 最终输出的数字
    
        # 第一阶段:目标检测
        for obj in tf.detect(detect_net, img):
            x1, y1, x2, y2, label, score = obj
    
            if score > 0.70:  # 置信度阈值
    
                #扩大检测范围
                x3=x1-0.01
                y3=y1-0.01
                x4=x2+0.01
                y4=y2+0.01
    
                # 转换归一化坐标为像素坐标
                x = int(x3 * 240)
                y = int(y3 * img.height())
                w = int((x4 - x3) * 240)
                h = int((y4 - y3) * img.height())
    
                # 确保边界框在图像范围内
                x = max(0, x)
                y = max(0, y)
                w = min(w, img.width() - x)
                h = min(h, img.height() - y)
    
                if w <= 5 or h <= 5:  # 跳过无效区域
                    continue
    
                # 绘制检测框
                img.draw_rectangle((x, y, w, h), color=(0, 255, 0), thickness=2)
                center_x = x + w // 2
                center_y = y + h // 2
                img.draw_cross(center_x, center_y, color=(0, 255, 255), size=5)
    
                # 记录检测信息
                detect_info = {
                    'position': (x, y, w, h),
                    'center': (center_x, center_y),
                    'detect_score': score
                }
    
                # 第二阶段:在检测框内进行分类
                roi = (x, y, w, h)
                roi_img = img.copy(roi=roi)  # 复制ROI区域
    
                # 执行分类
                for obj_class in tf.classify(classify_net, roi_img):
                    outputs = obj_class.output()
                    max_idx = outputs.index(max(outputs))
                    label_str = classify_labels[max_idx]
                    confidence = outputs[max_idx]
    
                    # 显示分类结果
                    text = "{}:{:.1f}%".format(label_str, confidence*100)
                    img.draw_string(x, y-10, text, color=(255, 0, 0), scale=1.5)
    
                    # 解析标签
                    class_num, class_dir = parse_label(label_str)
    
                    # 记录分类信息
                    detect_info['class_str'] = label_str
                    detect_info['class_num'] = class_num
                    detect_info['class_dir'] = class_dir
                    detect_info['class_confidence'] = confidence*100
    
                # 添加到检测结果列表
                detected_objects.append(detect_info)
    
                # 释放内存
                del roi_img
                gc.collect()
    
        # 根据目标数量处理数字逻辑
        num_objects = len(detected_objects)
        if num_objects == 1:
            # 单个目标直接输出
            final_number = detected_objects[0]['class_num']
            print("Single object detected, output:", final_number)
    
        elif num_objects == 2:
            obj1, obj2 = detected_objects[0], detected_objects[1]
            num1, num2 = obj1['class_num'], obj2['class_num']
            dir1, dir2 = obj1['class_dir'], obj2['class_dir']
    
            print(f"Two objects: {num1}{dir1} and {num2}{dir2}")
    
            # 检查是否有0
            if num1 == 0 or num2 == 0:
                ten_digit = num1 if num1 != 0 else num2
                unit_digit = 0
                final_number = ten_digit * 10 + unit_digit
                print("One object is 0, output:", final_number)
            else:
                # 计算中心点差值
                dx = abs(obj1['center'][0] - obj2['center'][0])
                dy = abs(obj1['center'][1] - obj2['center'][1])
                print(f"Center differences - dx: {dx}, dy: {dy}")
    
                if dx > dy:  # 水平方向差异更大
                    # 选择x坐标更小的对象(更靠左)
                    if obj1['center'][0] < obj2['center'][0]:
                        left_obj, right_obj = obj1, obj2
                    else:
                        left_obj, right_obj = obj2, obj1
    
                    print(f"Horizontal priority - Left object: {left_obj['class_str']}")
    
                    # 检查特殊数字 (1或8)
                    if left_obj['class_num'] in [1, 8]:
                        ten_digit = right_obj['class_num']
                        unit_digit = left_obj['class_num']
                        print("Left is 1/8, reverse assignment")
                    else:
                        if 'L' in left_obj['class_dir']:
                            ten_digit = left_obj['class_num']
                            unit_digit = right_obj['class_num']
                        elif 'R' in left_obj['class_dir']:
                            ten_digit = right_obj['class_num']
                            unit_digit = left_obj['class_num']
                        else:  # 默认左侧为十位
                            ten_digit = left_obj['class_num']
                            unit_digit = right_obj['class_num']
    
                    final_number = ten_digit * 10 + unit_digit
    
                else:  # 垂直方向差异更大
                    # 选择y坐标更小的对象(更靠上)
                    if obj1['center'][1] < obj2['center'][1]:
                        top_obj, bottom_obj = obj1, obj2
                    else:
                        top_obj, bottom_obj = obj2, obj1
    
                    print(f"Vertical priority - Top object: {top_obj['class_str']}")
    
                    # 检查特殊数字 (1或8)
                    if top_obj['class_num'] in [1, 8]:
                        ten_digit = bottom_obj['class_num']
                        unit_digit = top_obj['class_num']
                        print("Top is 1/8, reverse assignment")
                    else:
                        if 'U' in top_obj['class_dir']:
                            ten_digit = bottom_obj['class_num']
                            unit_digit = top_obj['class_num']
                        elif 'D' in top_obj['class_dir']:
                            ten_digit = top_obj['class_num']
                            unit_digit = bottom_obj['class_num']
                        else:  # 默认上方为十位
                            ten_digit = top_obj['class_num']
                            unit_digit = bottom_obj['class_num']
    
                    final_number = ten_digit * 10 + unit_digit
    
                print(f"Two objects combined: {final_number}")
    
        elif num_objects > 2:
            final_number = 100
            print(f"Multiple objects ({num_objects}), output: 100")
    
        # 终端输出结果
        print("\n------ Frame Info ------")
        print("FPS: {:.1f}".format(clock.fps()))
        print(f"Detected objects: {num_objects}")
    
        if detected_objects:
            print("\nDetected Objects:")
            for i, obj in enumerate(detected_objects, 1):
                x, y, w, h = obj['position']
                cx, cy = obj['center']
                print(f"Object {i}:")
                print(f"  Position: ({x}, {y}) Size: {w}x{h}")
                print(f"  Center: ({cx}, {cy})")
                print(f"  Detect Score: {obj['detect_score']:.2f}")
                print(f"  Class: {obj['class_str']} (Num: {obj['class_num']}, Dir: {obj['class_dir']})")
                print(f"  Class Confidence: {obj['class_confidence']:.1f}%")
        else:
            print("No objects detected")
    
        if final_number is not None:
            print(f"Final Output Number: {final_number}")
    
        print("-----------------------")
    
        # 显示帧率和最终结果
        fps_text = "FPS:{:.1f}".format(clock.fps())
        img.draw_string(5, 5, fps_text, color=(0, 0, 255), scale=2.0)
    
        if final_number is not None:
            result_text = "Num:{}".format(final_number)
            img.draw_string(img.width()//2-30, img.height()-30, result_text, color=(255, 0, 0), scale=2.0)
    

    您可能感兴趣的与本文相关的镜像

    Yolo-v5

    Yolo-v5

    Yolo

    YOLO(You Only Look Once)是一种流行的物体检测和图像分割模型,由华盛顿大学的Joseph Redmon 和Ali Farhadi 开发。 YOLO 于2015 年推出,因其高速和高精度而广受欢迎

    评论
    添加红包

    请填写红包祝福语或标题

    红包个数最小为10个

    红包金额最低5元

    当前余额3.43前往充值 >
    需支付:10.00
    成就一亿技术人!
    领取后你会自动成为博主和红包主的粉丝 规则
    hope_wisdom
    发出的红包
    实付
    使用余额支付
    点击重新获取
    扫码支付
    钱包余额 0

    抵扣说明:

    1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
    2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

    余额充值