除了传统目标检测方法yolov系列,ssd,faster-rnnn等之外,基于像素分割的Mask-RCNN网络也可以做目标检测,尤其最近在做版面分析,看到有人用mask-rcnn做票据识别的效果不错,可以准确定位出票据上面的关键信息点,于是特意研究了一番。
mask-rcnn的GitHub地址:https://github.com/matterport/Mask_RCNN
需求说明:我有火车票和票据数据一共20张(自己用手机拍摄的),我现在想提取这些票据上面我关心的关键信息条目。
解决办法:使用mask-rcnn解决
第一步:数据准备,使用labelme标注我的数据,GitHub地址https://github.com/wkentaro/labelme#anaconda,我是mac机,所以按照mac机的安装方法执行了两条命令就装好了.安装完成后打开命令提示符,执行labelme即可出现如下界面
此为标注界面,点击左侧的“openDir"按钮,选择自己的数据文件夹目录,然后开始标注。标注的时候我只标注了我感兴趣的条目,并分别分类为1-7之间。
标注完后再数据目录会出现和图片命名一样的json文件,将这些json文件单独放到一个文件夹,然后编辑一个test.sh的文件,在里面写入如下内容:
#!/bin/bash
for((i=33;i<57;i++))
do
labelme_json_to_dataset /Users/yjf/work/datasets/labelme/WechatIMG124${i}.json
done
该代码为将json文件转换为可以训练的文件数据。注意改一下第二行和第四行,我的数据有编号,所以这样写的,大家根据自己的写,总之让第四行的第二个参数为自己的json文件就行。然后在命令行以此执行chmod 777 test.sh 和sh test.sh即可开始转换。转换后会生成多个文件夹,文件夹里面存放的就是标注结果,每个文件夹有5个文件,有原图,标注结果等,大家标注完了自己看就行。然后将这些生成的文件夹单独存放在个文件夹中,此文件夹即为等会训练要用的文件夹了,我的目录为/Users/yjf/work/datasets/labelme, 注意,如果目录中有.DS_Store,记得删掉。
第二步:下载mask-rcnn:GitHub地址:https://github.com/matterport/Mask_RCNN,下载完后再下载coco预训练的模型https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5,将其存放到下载好的Mask_RCNN目录中根目录下(和mrcnn在同一个目录下)。
第三步:修改网络。在sample目录下新建yjf_test目录,我在该目录下新建了一个railway_test.py的文件,里面的内容如下:
import os
import sys
import random
import yaml
import math
import re
import time
import numpy as np
import cv2
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
# Root directory of the project
ROOT_DIR = os.path.abspath("./")
# Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log
# %matplotlib inline
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs_")
# Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH)
############################################################
# Configurations
############################################################
class ShapesConfig(Config):
"""Configuration for training on the toy shapes dataset.
Derives from the base Config class and overrides values specific
to the toy shapes dataset.
"""
# Give the configuration a recognizable name
NAME = "shapes"
# Train on 1 GPU and 8 images per GPU. We can put multiple images on each
# GPU because the images are small. Ba