EAST: An Efficient and Accurate Scene Text Detector 自然场景下的文字识别(原理及代码理解)

本文介绍了EAST模型,一种高效准确的自然场景文本检测管道。该模型通过直接预测图像中的文字区域,简化了文本检测流程。文章详细阐述了EAST的网络结构、特征提取、特征融合以及输出层等内容。

最近在学习自然场景下的文字识别,有一个比较新的模型EAST,所以学习一下。

论文原地址:https://arxiv.org/abs/1704.03155v2
源码:https://github.com/argman/EAST

模型特点及优势

该模型直接预测全图像中任意方向和四边形形状的单词或文本行,消除不必要的中间步骤(例如,候选聚合和单词分割)。通过下图它与一些其他方式的步骤对比,可以发现该模型的步骤比较简单,去除了中间一些复杂的步骤,所以符合它的特点,EAST, since it is an Efficient and Accuracy Scene Text detection pipeline.

这里写图片描述

网络结构

这里写图片描述

第一部分:Feature extractor stem(PVANet)

利用Inception的思想,即不同尺寸的卷积核的组合可以适应多尺度目标的检测,作者在这里采用PVANet模型,提取不同尺寸卷积核下的特征并用于后期的特征组合。
代码描述:
# 网络结构1:首先是一个resnet_v1_50网络

with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
    logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')
with tf.variable_scope('feature_fusion', values=[end_points.values]):
    batch_norm_params = {
    'decay': 0.997,
    'epsilon': 1e-5,
    'scale': True,
    'is_training': is_training
    }
    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu, # 激活函数是relu
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params,
                        weights_regularizer=slim.l2_regularizer(weight_decay)): # L2正则
        f = [end_points['pool5'], end_points['pool4'],
             end_points['pool3'], end_points['pool2']]
第二部分:Feature-merging branch

在这一部分用来组合特征,并通过上池化和concat恢复到原图的尺寸,这里借鉴的是U-net的思想。
所谓上池化一般是指最大池化的逆过程,实际上是不能实现的但是,可以通过只把池化过程中最大激活值所在的位置激活,其他位置设为0,完成上池化的近似过程。
g和h的计算过程如下图所示。
这里写图片描述
代码描述:

         for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                if i == 0:
                    h[i] = f[i]  # 计算h
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                if i <= 2:
                    g[i] = unpool(h[i]) # 计算g
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))
第三部分:Output Layer

上一部分的输出通过一个(1x1,1)的卷积核获得score_map。score_map与原图尺寸一致,每一个值代表此处是否有文字的可能性。
上一部分的输出通过一个(1x1,4)的卷积核获得RBOX 的geometry_map。有四个通道,分别代表每个像素点到文本矩形框上,右,底,左边界的距离。另外再通过一个(1x1, 1)的卷积核获得该框的旋转角,这是为了能够识别出有旋转的文字。
上一部分的输出通过一个(1x1,8)的卷积核获得QUAD的geometry_map,八个通道分别代表每个像素点到任意四边形的四个顶点的距离。
具体如下图所示:
这里写图片描述
代码描述:

      # 计算score_map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # 4 channel of axis aligned bbox and 1 channel rotation angle
            # 计算RBOX的geometry_map
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            # angle is between [-45, 45] #计算angle
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)

代价函数

代价函数分两部分,如下,第一部分是分类误差,第二部分是几何误差,文中权衡重要性,λg=1。
这里写图片描述

分类误差函数

采用 class-balanced cross-entropy,这样做可以很实用的处理正负样本不均衡的问题。
这里写图片描述
其中:
这里写图片描述
即β=反例样本数量/总样本数量 (balance factor)
代码描述:

# 计算score map的loss
def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls:
    :param y_pred_cls:
    :param training_mask:
    :return:
    '''
    eps = 1e-5
    intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
    union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    tf.summary.scalar('classification_dice_loss', loss)
    return loss
几何误差函数
  1. 对于RBOX,采用IoU loss
    这里写图片描述
    角度误差则为:
    这里写图片描述
  2. 对于QUAD采用smoothed L1 loss
    CQ={x1,y1,x2,y2,x3,y3,x4,y4}
    这里写图片描述
    NQ*指的是四边形最短边的长度
    代码描述:
def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, contraning two part,
    the first part we use dice loss instead of weighted logloss,
    the second part is the iou loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction os text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore some text annotated by ###
    :return:
    '''
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    # scale classification loss to match the iou loss part
    classification_loss *= 0.01

    # d1 -> top, d2->right, d3->bottom, d4->left
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union  #计算R_true与R_pred的交集
    area_union = area_gt + area_pred - area_intersect  #计算R_true与R_pred的并集
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0)) # IoU loss,加1为了防止交集为0,log0没意义
    L_theta = 1 - tf.cos(theta_pred - theta_gt) # 夹角的loss
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls * training_mask))
    L_g = L_AABB + 20 * L_theta # geometry_map loss

    return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
采用NMS进行几何过滤

在假设来自附近像素的几何图形倾向于高度相关的情况下,逐行合并几何图形,并且在合并同一行中的几何图形时将迭代合并当前遇到的几何图形。

测试结果:
这里写图片描述
这里写图片描述

缺点:
由于感受野不够大,所以对于较长的文字比较难识别。

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值