手写中文文章识别(4)——模型搭建

手写中文文章识别(1)——问题描述

https://blog.youkuaiyun.com/foreseerwang/article/details/80833749

手写中文文章识别(2)——样本集构建

https://blog.youkuaiyun.com/foreseerwang/article/details/80842498

手写中文文章识别(3)——data feeding

https://blog.youkuaiyun.com/foreseerwang/article/details/80914473

 

作为Tensorflow的初级使用者,我认为,Tensorflow的程序架构可以大致分为3方面内容:data feeding、模型搭建、模型训练。本篇主要介绍模型搭建。

 

前文已经提到,本项目最终采用的模型结构为:CNN + 双向LSTM + Viterbi算法(CRF)。先上代码,再来解读。

def get_a_cell(lstm_size, keep_prob):
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    return drop


def model(images, seq_lens, keep_prob, is_training):
    
    ## CNN
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    normalizer_fn=slim.batch_norm,
                    normalizer_params={'is_training': is_training}):
        conv3_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv3_1')
        max_pool_1 = slim.max_pool2d(conv3_1, [2, 2], [2, 2], padding='SAME', scope='pool1')
        conv3_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv3_2')
        max_pool_2 = slim.max_pool2d(conv3_2, [2, 2], [2, 2], padding='SAME', scope='pool2')
        conv3_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3_3')
        max_pool_3 = slim.max_pool2d(conv3_3, [2, 2], [2, 2], padding='SAME', scope='pool3')
        conv3_4 = slim.conv2d(max_pool_3, 512, [3, 3], padding='SAME', scope='conv3_4')
        conv3_5 = slim.conv2d(conv3_4, 512, [3, 3], padding='SAME', scope='conv3_5')
        max_pool_4 = slim.max_pool2d(conv3_5, [2, 2], [2, 2], padding='SAME', scope='pool4')

        flatten = slim.flatten(max_pool_4)
        fc1_2d = slim.fully_connected(slim.dropout(flatten, keep_prob), FLAGS.char_vec_len,
                                   activation_fn=tf.nn.relu, scope='fc1')
        fc1 = tf.reshape(fc1_2d, [-1, FLAGS.sent_len_max, FLAGS.char_vec_len])
        
    ## 双向LSTM
    cell_fw = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(FLAGS.lstm_size, keep_prob) for _ in range(FLAGS.num_layers)]
    )

    cell_bw = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(FLAGS.lstm_size, keep_prob) for _ in range(FLAGS.num_layers)]
    )

    bi_lstm_outputs, final_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, fc1,
                                                                   sequence_length=seq_lens,
                                                                   dtype=tf.float32)

    lstm_outputs = tf.concat(bi_lstm_outputs, -1)

    x = tf.reshape(lstm_outputs, [-1, 2 * FLAGS.lstm_size])

    ## 全连接
    softmax_w = tf.get_variable('w', [2 * FLAGS.lstm_size, charset_size],
                                initializer=tf.zeros_initializer())
    softmax_b = tf.get_variable('b', [charset_size],
                                initializer=tf.zeros_initializer())

    logits = tf.matmul(x, softmax_w) + softmax_b

    return logits


def build_graph(top_k, images, labels, seq_lens, mask, keep_prob,
                is_training, reuse_variables=False):

    with tf.variable_scope("encoder") as scope:
        if reuse_variables:
            tf.get_variable_scope().reuse_variables()

        logits = model(images, seq_lens, keep_prob, is_training)
        
        ## Viterbi解码
        if FLAGS.viterbi:
            logits_reshaped = tf.reshape(logits,[-1,FLAGS.sent_len_max,charset_size])
            labels_reshaped = tf.reshape(labels,[-1,FLAGS.sent_len_max])

            log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
                logits_reshaped,labels_reshaped,seq_lens)

            loss = tf.reduce_mean(-log_likelihood)

            labels_pred, _ = tf.contrib.crf.crf_decode(
                logits_reshaped, transition_params, seq_lens)
            logits_masked = tf.boolean_mask(tf.reshape(labels_pred, [-1]), mask)
            labels_masked = tf.boolean_mask(labels, mask)

            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(logits_masked, tf.int64),
                                                       tf.cast(labels_masked, tf.int64)),
                                              tf.float32))
        else:
            logits_masked = tf.boolean_mask(logits, mask)
            labels_masked = tf.boolean_mask(labels, mask)
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_masked,
                                                                                 labels=labels_masked))

            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits_masked, 1),
                                                       tf.cast(labels_masked, tf.int64)),
                                              tf.float32))

        ## 构建loss和train op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            loss = control_flow_ops.with_dependencies([updates], loss)

        global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0),
                                      trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = slim.learning.create_train_op(loss, optimizer, global_step=global_step)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        merged_summary_op = tf.summary.merge_all()

        return {'images': tf.boolean_mask(images,mask),
                'labels': labels_masked,
                'logits_pred_masked': logits_masked,
                'keep_prob': keep_prob,
                'top_k': top_k,
                'global_step': global_step,
                'train_op': train_op,
                'loss': loss,
                'is_training': is_training,
                'accuracy': accuracy,
                'merged_summary_op': merged_summary_op,
                'mask': mask,
                'seq_lens': seq_lens}

 

这里面包括三个函数,其中第一个get_a_cell用于LSTM构建,没啥可说的;model是从image输入经过CNN和双向LSTM到logits的完整模型,这里的CNN模型和之前hwdb手写汉字识别参考博客中的一样,也很简单直观;build_graph函数的主要目的是构建用于训练的loss和train_op。

 

下面详细介绍build_graph函数。首先是下面这句代码:

if reuse_variables:
    tf.get_variable_scope().reuse_variables()

如果希望在训练的过程中,每隔一段时间观察CV集(cross validation)的精度情况,以控制训练过程,那么这句代码至关重要,它的作用是重用当前的各个模型参数(或者说当前训练出的各个模型参数,而不是重新建立新的一组参数)。详细解释可见:https://blog.youkuaiyun.com/foreseerwang/article/details/79499553

 

Viterbi算法代码:

logits_reshaped = tf.reshape(logits,[-1,FLAGS.sent_len_max,charset_size])
labels_reshaped = tf.reshape(labels,[-1,FLAGS.sent_len_max])

log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
                logits_reshaped,labels_reshaped,seq_lens)

loss = tf.reduce_mean(-log_likelihood)

labels_pred, _ = tf.contrib.crf.crf_decode(
                logits_reshaped, transition_params, seq_lens)
logits_masked = tf.boolean_mask(tf.reshape(labels_pred, [-1]), mask)
labels_masked = tf.boolean_mask(labels, mask)

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(logits_masked, tf.int64),
                                           tf.cast(labels_masked, tf.int64)),
                                  tf.float32))

这是Viterbi算法的典型代码,其核心是通过crf_log_likelihood函数学习出一个转换矩阵transition_params,并把它用在crf_decode函数中以考虑前后文,获得预测labels。除了crf_decode,tf中还有一个viterbi_decode,二者的功能一样,只是crf_decode在tensorflow中完成,而viterbi_decode需要在tensorflow外部完成。这里给出了是否打开Viterbi算法的开关,实测中发现,打开Viterbi算法,相比于不使用Viterbi算法精度上确实会有提升,但也会带来训练时间20倍以上的增加。

 

loss和train_op的构建都是常规代码,倒没什么可说的了。

 

以上。很遗憾,这部分代码没有单独的验证。完整的训练和validation结果将在下一篇文章中给出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值