基于CRNN的不定长文字识别原理与实现

本文详细介绍了CRNN在文字识别中的工作原理,包括如何将CNN和RNN结合,以及如何通过CTC算法处理不定长文本对齐问题。重点讲解了网络结构、训练过程和实际应用案例,展示了CRNN在OCR任务中的高效性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

基于CRNN的不定长文字识别

推荐阅读

CRNN文字识别原理

在以前的OCR任务中,识别过程分为两步:单字切割和分类任务。我们一般都会讲一连串文字的文本文件先利用投影法切割出单个字体,在送入CNN里进行文字分类。但是此法已经有点过时了,现在更流行的是基于深度学习的端到端的文字识别,即我们不需要显式加入文字切割这个环节,而是将文字识别转化为序列学习问题,虽然输入的图像尺度不同,文本长度不同,但是经过DCNN和RNN后,在输出阶段经过一定的翻译后,就可以对整个文本图像进行识别,也就是说,文字的切割也被融入到深度学习中去了。

现今基于深度学习的端到端OCR技术有两大主流技术:CRNN OCR和attention OCR。其实这两大方法主要区别在于最后的输出层(翻译层),即怎么将网络学习到的序列特征信息转化为最终的识别结果。这两大主流技术在其特征学习阶段都采用了CNN+RNN的网络结构,CRNN OCR在对齐时采取的方式是CTC算法,而attention OCR采取的方式则是attention机制。这里介绍应用更为广泛的CRNN算法。
在这里插入图片描述

网络结构包含三部分,从下到上依次为:

  1. 卷积层,使用CNN,作用是从输入图像中提取特征序列;
  2. 循环层,使用RNN,作用是预测从卷积层获取的特征序列的标签(真实值)分布;
  3. 转录层,使用CTC,作用是把从循环层获取的标签分布通过去重整合等操作转换成最终的识别结果;
    在这里插入图片描述

端到端OCR的难点在哪儿呢?在于怎么处理不定长序列对齐问题!CRNN OCR其实是借用了语音识别中解决不定长语音序列的思路。与语音识别问题类似,OCR可建模为时序依赖的词汇或者短语识别问题。基于联结时序分类(Connectionist Temporal Classification, CTC)训练RNN的算法,在语音识别领域显著超过传统语音识别算法。一些学者尝试把CTC损失函数借鉴到OCR识别中,CRNN 就是其中代表性算法。CRNN算法输入100*32归一化高度的词条图像,基于7层CNN(普遍使用VGG16)提取特征图,把特征图按列切分(Map-to-Sequence),每一列的512维特征,输入到两层各256单元的双向LSTM进行分类。在训练过程中,通过CTC损失函数的指导,实现字符位置与类标的近似软对齐。

CRNN借鉴了语音识别中的LSTM+CTC的建模方法,不同点是输入进LSTM的特征,从语音领域的声学特征(MFCC等),替换为CNN网络提取的图像特征向量。CRNN算法最大的贡献,是把CNN做图像特征工程的潜力与LSTM做序列化识别的潜力,进行结合。它既提取了鲁棒特征,又通过序列识别避免了传统算法中难度极高的单字符切分与单字符识别,同时序列化识别也嵌入时序依赖(隐含利用语料)。在训练阶段,CRNN将训练图像统一缩放100×32(w × h);在测试阶段,针对字符拉伸导致识别率降低的问题,CRNN保持输入图像尺寸比例,但是图像高度还是必须统一为32个像素,卷积特征图的尺寸动态决定LSTM时序长度。这里举个例子

现在输入有个图像,为了将特征输入到Recurrent Layers,做如下处理:

  • 首先会将图像缩放到 32×W×1 大小
  • 然后经过CNN后变为 1×(W/4)× 512
  • 接着针对LSTM,设置 T=(W/4) , D=512 ,即可将特征输入LSTM。
  • LSTM有256个隐藏节点,经过LSTM后变为长度为T × nclass的向量,再经过softmax处理,列向量每个元素代表对应的字符预测概率,最后再将这个T的预测结果去冗余合并成一个完整识别结果即可。
    在这里插入图片描述
    CRNN中需要解决的问题是图像文本长度是不定长的,所以会存在一个对齐解码的问题,所以RNN需要一个额外的搭档来解决这个问题,这个搭档就是著名的CTC解码。
    CRNN采取的架构是CNN+RNN+CTC,cnn提取图像像素特征,rnn提取图像时序特征,而ctc归纳字符间的连接特性。

那么CTC有什么好处?因手写字符的随机性,人工可以标注字符出现的像素范围,但是太过麻烦,ctc可以告诉我们哪些像素范围对应的字符:
在这里插入图片描述
我们知道,CRNN中RNN层输出的一个不定长的序列,比如原始图像宽度为W,可能其经过CNN和RNN后输出的序列个数为S,此时我们要将该序列翻译成最终的识别结果。RNN进行时序分类时,不可避免地会出现很多冗余信息,比如一个字母被连续识别两次,这就需要一套去冗余机制,但是简单地看到两个连续字母就去冗余的方法也有问题,比如cook,geek一类的词,所以CTC有一个blank机制来解决这个问题。这里举个例子来说明。
在这里插入图片描述
如上图所示,我们要识别这个手写体图像,标签为“ab”,经过CNN+RNN学习后输出序列向量长度为5,即t0~t4,此时我们要将该序列翻译为最后的识别结果。我们在翻译时遇到的第一个难题就是,5个序列怎么转化为对应的两个字母?重复的序列怎么解决?刚好位于字与字之间的空白的序列怎么映射?这些都是CTC需要解决的问题。

我们从肉眼可以看到,t0,t1,t2时刻都应映射为“a”,t3,t4时刻都应映射为“b”。如果我们将连续重复的字符合并成一个输出的话,即“aaabb”将被合并成“ab”输出。但是这样子的合并机制是有问题的,比如我们的标签图像时“aab”时,我们的序列输出将可能会是“aaaaaaabb”,这样子我们就没办法确定该文本应被识别为“aab”还是“ab”。CTC为了解决这种二义性,提出了插入blank机制,比如我们以“-”符号代表blank,则若标签为“aaa-aaaabb”则将被映射为“aab”,而“aaaaaaabb”将被映射为“ab”。引入blank机制,我们就可以很好地处理了重复字符的问题了。

但我们还注意到,“aaa-aaaabb”可以映射为“aab”,同样地,“aa-aaaaabb”也可以映射为“aab”,也就是说,存在多个不同的字符组合可以映射为“aab”,更总结地说,一个标签存在一条或多条的路径。比如下面“state”这个例子,也存在多条不同路径映射为"state":
在这里插入图片描述
上面提到,RNN层输出的是序列中概率矩阵,那么

p ( π = − − s t t a − t − − − e ∣ x , S ) = ∏ t = 1 T y π t t = ( y − 1 ) × ( y − 2 ) × ( y s 3 ) × ( y t 4 ) × ( y t 5 ) × ( y a 6 ) × ( y − 7 ) × ( y t 8 ) × ( y − 9 ) × ( y − 10 ) × ( y − 11 ) × ( y e 12 ) p(\pi=--stta-t---e|x,S)=\prod_{t=1}^{T}y_{\pi_{t}}^{t}=(y_{-}^{1})\times(y_{-}^{2})\times(y_{s}^{3})\times(y_{t}^{4})\times(y_{t}^{5})\times(y_{a}^{6})\times(y_{-}^{7})\times(y_{t}^{8})\times(y_{-}^{9})\times(y_{-}^{10})\times(y_{-}^{11})\times(y_{e}^{12}) p(π=sttatex,S)=t=1Tyπtt=(y1)×(y2)×(ys3)×(yt4)×(yt5)×(ya6)×(y7)×(yt8)×(y9)×(y10)×(y11)×(ye12)

其中, y − 1 y_{-}^{1} y1表示第一个序列输出“-”的概率,那么对于输出某条路径𝜋的概率为各个序列概率的乘积。所以要得到一个标签可以有多个路径来获得,从直观上理解就是,我们输出一张文本图像到网络中,我们需要使得输出为标签L的概率最大化,由于路径之间是互斥的,对于标注序列,其条件概率为所有映射到它的路径概率之和:
在这里插入图片描述
其中 π ∈ B − 1 ( l ) \pi\in B^{-1}(l) πB1(l)的意思是,所有可以合并成l的所有路径集合。

这种通过映射B和所有候选路径概率之和的方式使得CTC不需要对原始的输入序列进行准确的切分,这使得RNN层输出的序列长度>label长度的任务翻译变得可能。CTC可以与任意的RNN模型,但是考虑到标注概率与整个输入串有关,而不是仅与前面小窗口范围的片段相关,因此双向的RNN/LSTM模型更为适合。

ctc会计算loss ,从而找到最可能的像素区域对应的字符。事实上,这里loss的计算本质是对概率的归纳:
在这里插入图片描述
如上图,对于最简单的时序为2的(t0t1)的字符识别,可能的字符为“a”,“b”和“-”,颜色越亮代表概率越高。我们如果采取最大概率路径解码的方法,一看就是“–”的概率最大,真实字符为空即“”的概率为0.6*0.6=0.36。

但是我们忽略了一点,真实字符为“a”的概率不只是”aa” 即0.4*0.4 , 事实上,“aa”, “a-“和“-a”都是代表“a”,所以,输出“a”的概率为:

0.4*0.4 + 0.4 * 0.6 + 0.6*0.4 = 0.16+0.24+0.24 = 0.64

所以“a”的概率比空“”的概率高!可以看出,这个例子里最大概率路径和最大概率序列完全不同,所以CTC解码通常不适合采用最大概率路径的方法,而应该采用前缀搜索算法解码或者约束解码算法。

通过对概率的计算,就可以对之前的神经网络进行反向传播更新。类似普通的分类,CTC的损失函数O定义为负的最大似然,为了计算方便,对似然取对数。

O = − l n ( ∏ ( x , z ) ∈ S p ( l ∣ x ) ) = − ∑ ( x , z ) ∈ S l n p ( l ∣ x ) O=-ln(\prod_{(x,z)\in S} p(l|x))=-\sum_{(x,z)\in S}lnp(l|x) O=ln((x,z)Sp(lx))=(x,z)Slnp(lx)

我们的训练目标就是使得损失函数O优化得最小即可。

下面将着重讲解CRNN代码实现过程以及识别效果。

数据处理

利用图像处理技术我们手工大批量生成文字图像,一共360万张图像样本,效果如下:
在这里插入图片描述
我们划分了训练集和测试集(10:1),并单独存储为两个文本文件:
在这里插入图片描述
文本文件里的标签格式如下:
在这里插入图片描述
我们获取到的是最原始的数据集,在图像深度学习训练中我们一般都会把原始数据集转化为lmdb格式以方便后续的网络训练。因此我们也需要对该数据集进行lmdb格式转化。下面代码就是用于lmdb格式转化,思路比较简单,就是首先读入图像和对应的文本标签,先使用字典将该组合存储起来(cache),再利用lmdb包的put函数把字典(cache)存储的k,v写成lmdb格式存储好(cache当有了1000个元素就put一次)。

import lmdb
import cv2
import numpy as np
import os


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    try:
        imageBuf = np.fromstring(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        imgH, imgW = img.shape[0], img.shape[1]
    except:
        return False
    else:
        if imgH * imgW == 0:
            return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k, v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = ''.join(imagePathList[i]).split()[0].replace('\n', '').replace('\r\n', '')
        # print(imagePath)
        label = ''.join(labelList[i])
        print(label)
        # if not os.path.exists(imagePath):
        #     print('%s does not exist' % imagePath)
        #     continue

        with open('.' + imagePath, 'r') as f:
            imageBin = f.read()

        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
        print(cnt)
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


OUT_PATH = '../crnn_train_lmdb'
IN_PATH = './train.txt'

if __name__ == '__main__':
    outputPath = OUT_PATH
    if not os.path.exists(OUT_PATH):
        os.mkdir(OUT_PATH)
    imgdata = open(IN_PATH)
    imagePathList = list(imgdata)

    labelList = []
    for line in imagePathList:
        word = line.split()[1]
        labelList.append(word)
    createDataset(outputPath, imagePathList, labelList)

我们运行上面的代码,可以得到训练集和测试集的lmdb
在这里插入图片描述
在数据准备部分还有一个操作需要强调的,那就是文字标签数字化,即我们用数字来表示每一个文字(汉字,英文字母,标点符号)。比如“我”字对应的id是1,“l”对应的id是1000,“?”对应的id是90,如此类推,这种编解码工作使用字典数据结构存储即可,训练时先把标签编码(encode),预测时就将网络输出结果解码(decode)成文字输出。

class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """

        length = []
        result = []
        for item in text:
            item = item.decode('utf-8', 'strict')

            length.append(len(item))
            for char in item:

                index = self.dict[char]
                result.append(index)

        text = result
        # print(text,length)
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

网络设计

根据CRNN的论文描述,CRNN是由CNN-》RNN-》CTC三大部分架构而成,分别对应卷积层、循环层和转录层。首先CNN部分用于底层的特征提取,RNN采取了BiLSTM,用于学习关联序列信息并预测标签分布,CTC用于序列对齐,输出预测结果。
在这里插入图片描述
为了将特征输入到Recurrent Layers,做如下处理:

  • 首先会将图像缩放到 32×W×3 大小
  • 然后经过CNN后变为 1×(W/4)× 512
  • 接着针对LSTM,设置 T=(W/4) , D=512 ,即可将特征输入LSTM。

以上是理想训练时的操作,但是CRNN论文提到的网络输入是归一化好的100×32大小的灰度图像,即高度统一为32个像素。下面是CRNN的深度神经网络结构图,CNN采取了经典的VGG16,值得注意的是,在VGG16的第3第4个max pooling层CRNN采取的是1×2的矩形池化窗口(w×h),这有别于经典的VGG16的2×2的正方形池化窗口,这个改动是因为文本图像多数都是高较小而宽较长,所以其feature map也是这种高小宽长的矩形形状,如果使用1×2的池化窗口则更适合英文字母识别(比如区分i和l)。VGG16部分还引入了BatchNormalization模块,旨在加速模型收敛。还有值得注意一点,CRNN的输入是灰度图像,即图像深度为1。CNN部分的输出是512x1x16(c×h×w)的特征向量。
在这里插入图片描述
接下来分析RNN层。RNN部分使用了双向LSTM,隐藏层单元数为256,CRNN采用了两层BiLSTM来组成这个RNN层,RNN层的输出维度将是(s,b,class_num) ,其中class_num为文字类别总数。

值得注意的是:Pytorch里的LSTM单元接受的输入都必须是3维的张量(Tensors).每一维代表的意思不能弄错。第一维体现的是序列(sequence)结构,第二维度体现的是小块(mini-batch)结构,第三位体现的是输入的元素(elements of input)。如果在应用中不适用小块结构,那么可以将输入的张量中该维度设为1,但必须要体现出这个维度。

LSTM的输入

input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. 
The input can also be a packed variable length sequence.
input shape(a,b,c)
a:seq_len  -> 序列长度
b:batch
c:input_size   输入特征数目 

根据LSTM的输入要求,我们要对CNN的输出做些调整,即把CNN层的输出调整为[seq_len, batch, input_size]形式,下面为具体操作:先使用squeeze函数移除h维度,再使用permute函数调整各维顺序,即从原来[w, b, c]的调整为[seq_len, batch, input_size],具体尺寸为[16,batch,512],调整好之后即可以将该矩阵送入RNN层。

x = self.cnn(x)
b, c, h, w = x.size()
# print(x.size()): b,c,h,w
assert h == 1   # "the height of conv must be 1"
x = x.squeeze(2)  # remove h dimension, b *512 * width
x = x.permute(2, 0, 1)  # [w, b, c] = [seq_len, batch, input_size]
x = self.rnn(x)

RNN层输出格式如下,因为我们采用的是双向BiLSTM,所以输出维度将是hidden_unit * 2

Outputs: output, (h_n, c_n)
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n (num_layers * num_directions, batch, hidden_size) 

然后我们再通过线性变换操作self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)是的输出维度再次变为512,继续送入第二个LSTM层。第二个LSTM层后继续接线性操作torch.nn.Linear(hidden_unit * 2, class_num)使得整个RNN层的输出为文字类别总数。

import torch
import torch.nn.functional as F


class Vgg_16(torch.nn.Module):

    def __init__(self):
        super(Vgg_16, self).__init__()
        self.convolution1 = torch.nn.Conv2d(1, 64, 3, padding=1)
        self.pooling1 = torch.nn.MaxPool2d(2, stride=2)
        self.convolution2 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.pooling2 = torch.nn.MaxPool2d(2, stride=2)
        self.convolution3 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.convolution4 = torch.nn.Conv2d(256, 256, 3, padding=1)
        self.pooling3 = torch.nn.MaxPool2d((1, 2), stride=(2, 1)) # notice stride of the non-square pooling
        self.convolution5 = torch.nn.Conv2d(256, 512, 3, padding=1)
        self.BatchNorm1 = torch.nn.BatchNorm2d(512)
        self.convolution6 = torch.nn.Conv2d(512, 512, 3, padding=1)
        self.BatchNorm2 = torch.nn.BatchNorm2d(512)
        self.pooling4 = torch.nn.MaxPool2d((1, 2), stride=(2, 1))
        self.convolution7 = torch.nn.Conv2d(512, 512, 2)

    def forward(self, x):
        x = F.relu(self.convolution1(x), inplace=True)
        x = self.pooling1(x)
        x = F.relu(self.convolution2(x), inplace=True)
        x = self.pooling2(x)
        x = F.relu(self.convolution3(x), inplace=True)
        x = F.relu(self.convolution4(x), inplace=True)
        x = self.pooling3(x)
        x = self.convolution5(x)
        x = F.relu(self.BatchNorm1(x), inplace=True)
        x = self.convolution6(x)
        x = F.relu(self.BatchNorm2(x), inplace=True)
        x = self.pooling4(x)
        x = F.relu(self.convolution7(x), inplace=True)
        return x  # b*512x1x16


class RNN(torch.nn.Module):
    def __init__(self, class_num, hidden_unit):
        super(RNN, self).__init__()
        self.Bidirectional_LSTM1 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)
        self.Bidirectional_LSTM2 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding2 = torch.nn.Linear(hidden_unit * 2, class_num)

    def forward(self, x):
        x = self.Bidirectional_LSTM1(x)   # LSTM output: output, (h_n, c_n)
        T, b, h = x[0].size()   # x[0]: (seq_len, batch, num_directions * hidden_size)
        x = self.embedding1(x[0].view(T * b, h))  # pytorch view() reshape as [T * b, nOut]
        x = x.view(T, b, -1)  # [16, b, 512]
        x = self.Bidirectional_LSTM2(x)
        T, b, h = x[0].size()
        x = self.embedding2(x[0].view(T * b, h))
        x = x.view(T, b, -1)
        return x  # [16,b,class_num]


# output: [s,b,class_num]
class CRNN(torch.nn.Module):
    def __init__(self, class_num, hidden_unit=256):
        super(CRNN, self).__init__()
        self.cnn = torch.nn.Sequential()
        self.cnn.add_module('vgg_16', Vgg_16())
        self.rnn = torch.nn.Sequential()
        self.rnn.add_module('rnn', RNN(class_num, hidden_unit))

    def forward(self, x):
        x = self.cnn(x)
        b, c, h, w = x.size()
        # print(x.size()): b,c,h,w
        assert h == 1   # "the height of conv must be 1"
        x = x.squeeze(2)  # remove h dimension, b *512 * width
        x = x.permute(2, 0, 1)  # [w, b, c] = [seq_len, batch, input_size]
        # x = x.transpose(0, 2)
        # x = x.transpose(1, 2)
        x = self.rnn(x)
        return x

损失函数设计

刚刚完成了CNN层和RNN层的设计,现在开始设计转录层,即将RNN层输出的结果翻译成最终的识别文字结果,从而实现不定长的文字识别。pytorch没有内置的CTC loss,所以只能去Github下载别人实现的CTC loss来完成损失函数部分的设计。安装CTC-loss的方式如下:

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding/
python setup.py install
cd ../build
cp libwarpctc.so ../../usr/lib

待安装完毕后,我们可以直接调用CTC loss了,以一个小例子来说明ctc loss的用法。

import torch
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True)  # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
CTCLoss(size_average=False, length_average=False)
    # size_average (bool): normalize the loss by the batch size (default: False)
    # length_average (bool): normalize the loss by the total number of frames in the batch. If True, supersedes size_average (default: False)

forward(acts, labels, act_lens, label_lens)
    # acts: Tensor of (seqLength x batch x outputDim) containing output activations from network (before softmax)
    # labels: 1 dimensional Tensor containing all the targets of the batch in one large sequence
    # act_lens: Tensor of size (batch) containing size of each output sequence from the network
    # label_lens: Tensor of (batch) containing label length of each example

从上面的代码可以看出,CTCLoss的输入为[probs, labels, probs_sizes, label_sizes],即预测结果、标签、预测结果的数目和标签数目。那么我们仿照这个例子开始设计CRNN的CTC LOSS。

preds = net(image)
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))  # preds.size(0)=w=16
cost = criterion(preds, text, preds_size, length) / batch_size   # 这里的length就是包含每个文本标签的长度的list,除以batch_size来求平均loss
cost.backward()

网络训练设计

接下来我们需要完善具体的训练流程,我们还写了个trainBatch函数用于bacth形式的梯度更新。

def trainBatch(net, criterion, optimizer, train_iter):
    data = train_iter.next()
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    lib.dataset.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    lib.dataset.loadData(text, t)
    lib.dataset.loadData(length, l)

    preds = net(image)
    #print("preds.size=%s" % preds.size)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))  # preds.size(0)=w=22
    cost = criterion(preds, text, preds_size, length) / batch_size  # length= a list that contains the len of text label in a batch
    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost

整个网络训练的流程如下:CTC-LOSS对象->CRNN网络对象->image,text,len的tensor初始化->优化器初始化,然后开始循环每个epoch,指定迭代次数就进行模型验证和模型保存。CRNN论文提到所采用的优化器是Adadelta,但是经过我实验看来,Adadelta的收敛速度非常慢,所以改用了RMSprop优化器,模型收敛速度大幅度提升。

    criterion = CTCLoss()

    net = Net.CRNN(n_class)
    print(net)

    net.apply(lib.utility.weights_init)

    image = torch.FloatTensor(Config.batch_size, 3, Config.img_height, Config.img_width)
    text = torch.IntTensor(Config.batch_size * 5)
    length = torch.IntTensor(Config.batch_size)

    if cuda:
        net.cuda()
        image = image.cuda()
        criterion = criterion.cuda()

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    loss_avg = lib.utility.averager()

    optimizer = optim.RMSprop(net.parameters(), lr=Config.lr)
    #optimizer = optim.Adadelta(net.parameters(), lr=Config.lr)
    #optimizer = optim.Adam(net.parameters(), lr=Config.lr,
                           #betas=(Config.beta1, 0.999))

    for epoch in range(Config.epoch):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            for p in net.parameters():
                p.requires_grad = True
            net.train()

            cost = trainBatch(net, criterion, optimizer, train_iter)
            loss_avg.add(cost)
            i += 1

            if i % Config.display_interval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, Config.epoch, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % Config.test_interval == 0:
                val(net, test_dataset, criterion)

            # do checkpointing
            if i % Config.save_interval == 0:
                torch.save(
                    net.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(Config.model_dir, epoch, i))

训练过程与测试设计

下面这幅图表示的就是CRNN训练过程,文字类别数为6732,一共训练20个epoch,batch_Szie设置为64,所以一共是51244次迭代/epoch。
在这里插入图片描述
在迭代4个epoch时,loss降到0.1左右,acc上升到0.98。
在这里插入图片描述
接下来我们设计推断预测部分的代码,首先需初始化CRNN网络,载入训练好的模型,读入待预测的图像并resize为高为32的灰度图像,接着讲该图像送入网络,最后再将网络输出解码成文字即可输出。

import time
import torch
import os
from torch.autograd import Variable
import lib.convert
import lib.dataset
from PIL import Image
import Net.net as Net
import alphabets
import sys
import Config

os.environ['CUDA_VISIBLE_DEVICES'] = "4"

crnn_model_path = './bs64_model/netCRNN_9_48000.pth'
IMG_ROOT = './test_images'
running_mode = 'gpu'
alphabet = alphabets.alphabet
nclass = len(alphabet) + 1


def crnn_recognition(cropped_image, model):
    converter = lib.convert.strLabelConverter(alphabet)  # 标签转换

    image = cropped_image.convert('L')  # 图像灰度化

    ### Testing images are scaled to have height 32. Widths are
    # proportionally scaled with heights, but at least 100 pixels
    w = int(image.size[0] / (280 * 1.0 / Config.infer_img_w))
    #scale = image.size[1] * 1.0 / Config.img_height
    #w = int(image.size[0] / scale)

    transformer = lib.dataset.resizeNormalize((w, Config.img_height))
    image = transformer(image)
    if torch.cuda.is_available():
        image = image.cuda()
    image = image.view(1, *image.size())
    image = Variable(image)

    model.eval()
    preds = model(image)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)  # 预测输出解码成文字
    print('results: {0}'.format(sim_pred))


if __name__ == '__main__':

    # crnn network
    model = Net.CRNN(nclass)
    
    # 载入训练好的模型,CPU和GPU的载入方式不一样,需分开处理
    if running_mode == 'gpu' and torch.cuda.is_available():
        model = model.cuda()
        model.load_state_dict(torch.load(crnn_model_path))
    else:
        model.load_state_dict(torch.load(crnn_model_path, map_location='cpu'))

    print('loading pretrained model from {0}'.format(crnn_model_path))

    files = sorted(os.listdir(IMG_ROOT))  # 按文件名排序
    for file in files:
        started = time.time()
        full_path = os.path.join(IMG_ROOT, file)
        print("=============================================")
        print("ocr image is %s" % full_path)
        image = Image.open(full_path)

        crnn_recognition(image, model)
        finished = time.time()
        print('elapsed time: {0}'.format(finished - started))

识别效果和总结

首先我从测试集中抽取几张图像送入模型识别,识别全部正确。在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
我也随机在一些文档图片、扫描图像上截取了一段文字图像送入我们该模型进行识别,识别效果也挺好的,基本识别正确,表明模型泛化能力很强。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
我还截取了增值税扫描发票上的文本图像来看看我们的模型能否还可以表现出稳定的识别效果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这里做个小小的总结:对于端到端不定长的文字识别,CRNN是最为经典的识别算法,而且实战看来效果非常不错。上面识别结果可以看出,虽然我们用于训练的数据集是自己生成的,但是我们该模型对于pdf文档、扫描图像等都有很不错的识别结果,如果需要继续提升对特定领域的文本图像的识别,直接大量加入该类图像用于训练即可。

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值