EAST: An Efficient and Accurate Scene Text Detector 自然场景下的文字识别(原理及代码理解)

本文介绍了EAST模型,一种高效准确的自然场景文本检测管道。该模型通过直接预测图像中的文字区域,简化了文本检测流程。文章详细阐述了EAST的网络结构、特征提取、特征融合以及输出层等内容。

最近在学习自然场景下的文字识别,有一个比较新的模型EAST,所以学习一下。

论文原地址:https://arxiv.org/abs/1704.03155v2
源码:https://github.com/argman/EAST

模型特点及优势

该模型直接预测全图像中任意方向和四边形形状的单词或文本行,消除不必要的中间步骤(例如,候选聚合和单词分割)。通过下图它与一些其他方式的步骤对比,可以发现该模型的步骤比较简单,去除了中间一些复杂的步骤,所以符合它的特点,EAST, since it is an Efficient and Accuracy Scene Text detection pipeline.

这里写图片描述

网络结构

这里写图片描述

第一部分:Feature extractor stem(PVANet)

利用Inception的思想,即不同尺寸的卷积核的组合可以适应多尺度目标的检测,作者在这里采用PVANet模型,提取不同尺寸卷积核下的特征并用于后期的特征组合。
代码描述:
# 网络结构1:首先是一个resnet_v1_50网络

with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
    logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')
with tf.variable_scope('feature_fusion', values=[end_points.values]):
    batch_norm_params = {
    'decay': 0.997,
    'epsilon': 1e-5,
    'scale': True,
    'is_training': is_training
    }
    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu, # 激活函数是relu
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params,
                        weights_regularizer=slim.l2_regularizer(weight_decay)): # L2正则
        f = [end_points['pool5'], end_points['pool4'],
             end_points['pool3'], end_points['pool2']]
第二部分:Feature-merging branch

在这一部分用来组合特征,并通过上池化和concat恢复到原图的尺寸,这里借鉴的是U-net的思想。
所谓上池化一般是指最大池化的逆过程,实际上是不能实现的但是,可以通过只把池化过程中最大激活值所在的位置激活,其他位置设为0,完成上池化的近似过程。
g和h的计算过程如下图所示。
这里写图片描述
代码描述:

         for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                if i == 0:
                    h[i] = f[i]  # 计算h
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                if i <= 2:
                    g[i] = unpool(h[i]) # 计算g
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))
第三部分:Output Layer

上一部分的输出通过一个(1x1,1)的卷积核获得score_map。score_map与原图尺寸一致,每一个值代表此处是否有文字的可能性。
上一部分的输出通过一个(1x1,4)的卷积核获得RBOX 的geometry_map。有四个通道,分别代表每个像素点到文本矩形框上,右,底,左边界的距离。另外再通过一个(1x1, 1)的卷积核获得该框的旋转角,这是为了能够识别出有旋转的文字。
上一部分的输出通过一个(1x1,8)的卷积核获得QUAD的geometry_map,八个通道分别代表每个像素点到任意四边形的四个顶点的距离。
具体如下图所示:
这里写图片描述
代码描述:

      # 计算score_map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # 4 channel of axis aligned bbox and 1 channel rotation angle
            # 计算RBOX的geometry_map
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            # angle is between [-45, 45] #计算angle
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)

代价函数

代价函数分两部分,如下,第一部分是分类误差,第二部分是几何误差,文中权衡重要性,λg=1。
这里写图片描述

分类误差函数

采用 class-balanced cross-entropy,这样做可以很实用的处理正负样本不均衡的问题。
这里写图片描述
其中:
这里写图片描述
即β=反例样本数量/总样本数量 (balance factor)
代码描述:

# 计算score map的loss
def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls:
    :param y_pred_cls:
    :param training_mask:
    :return:
    '''
    eps = 1e-5
    intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
    union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    tf.summary.scalar('classification_dice_loss', loss)
    return loss
几何误差函数
  1. 对于RBOX,采用IoU loss
    这里写图片描述
    角度误差则为:
    这里写图片描述
  2. 对于QUAD采用smoothed L1 loss
    CQ={x1,y1,x2,y2,x3,y3,x4,y4}
    这里写图片描述
    NQ*指的是四边形最短边的长度
    代码描述:
def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, contraning two part,
    the first part we use dice loss instead of weighted logloss,
    the second part is the iou loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction os text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore some text annotated by ###
    :return:
    '''
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    # scale classification loss to match the iou loss part
    classification_loss *= 0.01

    # d1 -> top, d2->right, d3->bottom, d4->left
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union  #计算R_true与R_pred的交集
    area_union = area_gt + area_pred - area_intersect  #计算R_true与R_pred的并集
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0)) # IoU loss,加1为了防止交集为0,log0没意义
    L_theta = 1 - tf.cos(theta_pred - theta_gt) # 夹角的loss
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls * training_mask))
    L_g = L_AABB + 20 * L_theta # geometry_map loss

    return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
采用NMS进行几何过滤

在假设来自附近像素的几何图形倾向于高度相关的情况下,逐行合并几何图形,并且在合并同一行中的几何图形时将迭代合并当前遇到的几何图形。

测试结果:
这里写图片描述
这里写图片描述

缺点:
由于感受野不够大,所以对于较长的文字比较难识别。

### 高效且准确的场景文本检测方法 #### 1. EAST (Efficient and Accurate Scene Text Detector) EAST 是一种端到端的文字检测算法,能够直接预测文字边界框。该模型设计简单有效,在保持较高准确性的同时实现了较快的速度。 - **特点**: - 可以处理任意形状的文字实例。 - 使用全卷积网络结构,无需复杂的预处理或后处理操作。 ```python import cv2 import numpy as np def east_text_detection(image_path): net = cv2.dnn.readNet('frozen_east_text_detection.pb') image = cv2.imread(image_path) orig = image.copy() (H, W) = image.shape[:2] layerNames = [ "feature_fusion/Conv_7/Sigmoid", "feature_fusion/concat_3" ] blob = cv2.dnn.blobFromImage(image, 1.0, (W, H), (123.68, 116.78, 103.94), swapRB=True, crop=False) net.setInput(blob) (scores, geometry) = net.forward(layerNames) # 进一步处理 scores 和 geometry 来获取最终的结果... ``` [^1] #### 2. CRAFT (Character Region Awareness For Text Detection) CRAFT专注于字符级别的特征学习,可以更精确地捕捉不规则排列的文字区域。此方法特别适合于弯曲或者倾斜的文字行。 - **优势**: - 对抗变形能力强,适用于多种复杂背景下的文本识别任务。 - 提供了更加细致化的分割结果,有助于后续的文字识别工作。 ```python from craft import CRAFT import torch from PIL import Image model = CRAFT(pretrained=True).eval() image = Image.open('path_to_image.jpg').convert('RGB') with torch.no_grad(): y, feature = model(torch.unsqueeze(transform(image), dim=0)) # 后续可以根据y来绘制边框或者其他可视化操作 ``` [^2] #### 3. PSENet (Pixel Aggregation Segmentation Network for Arbitrary-Shaped Scene Text Detection) PSENet利用像素聚合的思想来进行任意形状文本的检测。它通过多尺度融合的方式增强了对于细长物体的理解能力,并且能够在不同的比例尺下稳定表现。 - **亮点**: - 支持自由形态的文字轮廓提取。 - 结合了全局上下文信息和局部细节特性,提高了鲁棒性和泛化性。 ```python from pse import decode as pse_decode import torchvision.transforms.functional as F img_tensor = F.to_tensor(Image.open(img_file)).unsqueeze(0) preds = model(img_tensor)[-1].detach().cpu().numpy()[0] boxes = pse_decode(preds, min_area_threshold=500) ``` [^3]
评论 9
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值