mathAI:一位大佬使用LeNet5,使用CROHME数据集训练的数学公式识别模型。
最近需要做一个简单的四则运算的公式识别,也不需要识别什么复杂的公式,就小学能用到的程度就好了。一开始使用PaddleOCR来进行检测(ocr_det)识别(ocr_rec)(见下图),但是毕竟不是专门用来识别数学公式的,所以对于加号和除号等数学符号的识别效果不太好。所以希望自己训练一个简单的模型来对数字和数学符号进行识别,提高一点准确率。
上面那个项目使用的是LeNet5来实现,但是我想模型小一点,所以希望换MobileNets来实现试试。不过可以先用LeNet5实现一遍,熟悉一下流程。
CROHME数据集处理
第一个问题是,我第一次接触CROHME数据集,打开一看,都是*inkML后缀的标注文件,一脸懵逼,感觉光是数据集的转换提取就够我吃一壶了。
我要对图片里的数学字符进行单独识别,然后再组合起来。
所以就不需要对公式进行处理,只要提取公式里的字符作为训练集就行了。相当于一个简单的分类模型。
看网上数据提取的方式五花八门,我感觉只能自己试着按照自己的方式提取一下了。
按照字符将所有相同字符的存放到一个文件夹内这样,这样就省得每张图片都另外放一个文本进行标注了。
我使用的数据库是CROHME2013_data
,最后转换下来,有9万多的数据样本。
*.inkML文件内的结构:
<trace>
<trace id="18">
766 127, 766 127, 766 127, 766 127, 766 127, 764 127, 764 127, 764 127, 764 127, 764 127, 764 127, 763 127, 763 127, 763 127, 763 127, 763 127, 763 127, 763 127, 762 128, 762 128, 762 128, 762 128, 762 128, 763 127, 763 127, 763 127, 763 127, 765 127, 770 125, 770 125, 770 125, 770 125, 770 125, 777 124, 784 122, 784 122, 784 122, 784 122, 791 121, 791 121, 791 121, 797 120, 801 119, 801 119, 801 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 803 119, 802 120, 802 120, 800 121
</trace>
<trace id="19">
768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 123, 768 124, 768 124, 768 124, 768 126, 768 126, 768 126, 768 126, 768 126, 768 130, 768 134, 768 134, 768 134, 768 134, 767 138, 767 138, 767 138, 766 143, 765 146, 765 146, 765 146, 765 148, 765 148, 765 148, 765 149, 765 149, 765 149, 765 149, 765 149, 765 149, 765 149, 765 149, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 151, 764 149, 764 149, 764 149, 764 149, 766 148, 766 148, 766 148, 768 146, 768 146, 768 146, 768 146, 768 146, 768 146, 768 146, 770 144, 772 144, 772 144, 774 143, 774 143, 774 143, 774 143, 775 143, 776 143, 776 143, 776 143, 777 143, 777 143, 777 143, 780 144, 780 144, 780 144, 780 144, 781 144, 781 144, 781 144, 781 144, 781 144, 783 145, 783 145, 783 145, 783 145, 785 145, 786 147, 786 147, 786 147, 786 147, 786 147, 788 148, 788 148, 788 148, 790 149, 790 149, 790 150, 790 150, 790 150, 790 150, 790 150, 790 150, 790 153, 790 153, 790 153, 790 153, 790 153, 790 153, 789 155, 789 155, 789 155, 789 156, 789 156, 789 156, 789 156, 789 156, 789 157, 789 157, 789 157, 789 157, 789 159, 789 159, 789 159, 788 160, 788 160, 788 160, 788 160, 788 160, 788 160, 786 161, 786 161, 786 161, 786 161, 786 161, 786 163, 786 163, 786 163, 785 164, 785 164, 784 165, 784 165, 784 165, 784 165, 784 165, 782 166, 782 166, 782 166, 782 166, 782 166, 780 167, 780 167, 780 167, 779 168, 779 168, 779 168, 779 168, 777 168, 777 168, 777 168, 777 168, 775 168, 775 168, 775 168, 773 168, 773 168, 773 168, 773 168, 773 168, 773 168, 773 168, 773 168, 771 169, 771 169, 771 169, 771 169, 769 169, 767 169, 767 169, 767 169, 767 169, 767 169, 766 169, 766 169, 766 169, 766 169, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 765 168, 764 168, 764 168, 764 168, 764 168, 764 168, 764 168, 764 168, 764 168, 764 168, 764 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 762 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 761 168, 762 168
</trace>
<trace><\trace>
:这个标签里面的像素点是笔迹。需要一定的处理,才能绘制出图像。每个trace只表示一笔,就是类似x
这个单词,理论上放在一个trace
标签内表示也没问题,但是还是将两笔分成了两个trace
来保存。
traceGroup
<traceGroup xml:id="29">
<annotation type="truth">5</annotation>
<traceView traceDataRef="18"/>
<traceView traceDataRef="19"/>
<annotationXML href="5_1"/>
</traceGroup>
这个<traceGroup>
会将属于一个数学字符的笔画放在一起,组成一个完整的数学字符。类似5
这个数字,就由两笔组成。
traceDataRef
这个属性就是前面trace
的id
,我们利用这个信息,来建立标注和图片之间的联系。
为啥要这样来标注数据呢?我估计是有些别的的类似cos
,log
之类数学运算符,是由几个单词组成的,但是需要作为一个整体的数学字符来识别。
根据像素点生成图片
首先肯定要检查traceGroup
,要知道哪些笔迹traceDataRef
是属于一组的。然后再根据这个,去trace里找对应的笔画,组合在一起,成为一个字符。
相当于将原本一个inkML文件里的公式里的数学字符都分离独立,作为不同的字符的数据集。
写了个由inkml文件生成图片然后保存到对应分类文件夹的。
在转换的过程中出现了挺多问题的,首先,CROHME数据集应该是多个手写数据集的组合,里面的笔迹trace
的尺度都不太一样,一开始我遇到的大概是200~300范围的,然后开了几个压缩包,又看到10000左右的,还看到小数表示的,甚至有几个文件是负数。我感觉被玩弄于鼓掌之中
而且我最想要的小学用的除号\div
,也只有其中一个数据集(HAMEX)内有一些,其他大部分都是针对高等数学的感觉。
而且转换过程中的一些参数设定感觉也很有讲究,因为我先画点,然后用膨胀连接,感觉一开始定义的画布大小,坐标尺度,膨胀的核尺寸,都能对生成的数据图片的效果产生影响:
上面有的就生成得很好,下面的部分的点就有点离散。
不知道还能怎么处理来将点平滑的连接起来连起来?啊,经过一个小时,终于解决了,也明白数据集为啥要通过笔迹组合字符的良苦用心。就是trace
里面的点都是按照笔画顺序进行排列的,只要根据前后关系连线就好了,这样就能得到比较连续的图片了。也不需要进行闭运算了。
我用的cv2.line
来连线,看起来效果还不错。
inkml转img的代码我写得太糟糕了,等过段时间优化之后再放上来。
todo:还需要对小于一定尺寸(20 x 20
)像素的图片去除。
接下来终于可以开始准备网络模型了。
数据处理源代码
import os
from uuid import uuid4
from typing import Union, Dict, List
import cv2
import numpy as np
import xml.etree.ElementTree as ET
def load_dir(file_dir: str) -> Union[None]:
for file_name in os.listdir(file_dir):
suffix = file_name.split('.')[-1]
if suffix != 'inkml':
return
file_path = os.path.join(file_dir, file_name)
tracegroup_anno = parse_inkml_file(file_path)
merge_trace(tracegroup_anno, file_path)
def match_key(obj, keyword: str) -> bool:
return True if obj.tag.split('}')[1] == keyword else False
def parse_inkml_file(file_path: str) -> Dict[str, List]:
root = ET.parse(file_path).getroot()
anno = {}
for items in root:
if not match_key(items, 'traceGroup'):
continue
for item in items:
if not match_key(item, 'traceGroup'):
continue
trace_ids = [int(id.get('traceDataRef')) for id in item if match_key(id, 'traceView')]
annotation = [id.text for id in item if match_key(id, 'annotation') and id.get('type') == 'truth'][0]
anno[annotation] = trace_ids
return anno
def merge_trace(trace_anno: Dict, file_path: str):
special_sign = {chr(124): 'abs', '\\': ''}
root = ET.parse(file_path).getroot()
for word, trace_ids in trace_anno.items():
# path
word = word.translate(str.maketrans(special_sign))
output_dir = os.path.join('output', word)
os.makedirs(output_dir, exist_ok=True)
min_x, max_x = float('inf'), -float('inf')
min_y, max_y = float('inf'), -float('inf')
img = np.zeros((3500, 3500))
for trace_id in trace_ids:
for item in root:
if not match_key(item, 'trace') or not int(item.get('id')) == trace_id:
continue
processed_pixel_loc = item.text.replace('\n', '_').replace(' ', '_').split(',')
pixel_last = None
for f in processed_pixel_loc:
parts = f.split('_')
pixel = [round(float(k) * 100) for k in parts if k != ''] if '.' in f else [int(k) for k in parts if k.replace("-", "").isnumeric()]
if pixel_last is not None:
cv2.line(img, (pixel_last[0], pixel_last[1]), (pixel[0], pixel[1]), (255, 255, 255), 1)
pixel_last = pixel
min_x, max_x = min(min_x, pixel[0]), max(max_x, pixel[0])
min_y, max_y = min(min_y, pixel[1]), max(max_y, pixel[1])
img[pixel[1], pixel[0]] = 255
img = img_postprecess(img[min_y: max_y, min_x: max_x])
save(output_dir, img)
# display(img)
def img_postprecess(img: np.ndarray) -> np.ndarray:
img = np.pad(img, ((5, 5), (5, 5)), mode='constant')
img = cv2.dilate(img, np.ones((2, 2), np.uint8)) # 膨胀处理
img = 255 - img
return img
def save(save_dir: str, img: np.ndarray):
img_name = f"{str(uuid4())}.jpg"
img_path = os.path.join(save_dir, img_name)
cv2.imwrite(img_path, img)
def display(img: np.ndarray):
cv2.namedWindow('img', cv2.WINDOW_NORMAL)
cv2.imshow('img', img)
cv2.waitKey()
if __name__ == '__main__':
file = f"resource"
load_dir(file)