Handwritten Chinese Article Recognition (1): Problem Description
https://blog.youkuaiyun.com/foreseerwang/article/details/80833749
Handwritten Chinese Article Recognition (2): Building the Sample Set
https://blog.youkuaiyun.com/foreseerwang/article/details/80842498
Handwritten Chinese Article Recognition (3): Data Feeding
https://blog.youkuaiyun.com/foreseerwang/article/details/80914473
As a beginner with TensorFlow, I would roughly divide a TensorFlow program into three parts: data feeding, model construction, and model training. This post covers model construction.
As mentioned in the earlier posts, the model finally adopted for this project is CNN + bidirectional LSTM + Viterbi algorithm (CRF). The code comes first, followed by a walkthrough.
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

# FLAGS and charset_size are defined elsewhere in the project.

def get_a_cell(lstm_size, keep_prob):
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    return drop

def model(images, seq_lens, keep_prob, is_training):
    ## CNN
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': is_training}):
        conv3_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv3_1')
        max_pool_1 = slim.max_pool2d(conv3_1, [2, 2], [2, 2], padding='SAME', scope='pool1')
        conv3_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv3_2')
        max_pool_2 = slim.max_pool2d(conv3_2, [2, 2], [2, 2], padding='SAME', scope='pool2')
        conv3_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3_3')
        max_pool_3 = slim.max_pool2d(conv3_3, [2, 2], [2, 2], padding='SAME', scope='pool3')
        conv3_4 = slim.conv2d(max_pool_3, 512, [3, 3], padding='SAME', scope='conv3_4')
        conv3_5 = slim.conv2d(conv3_4, 512, [3, 3], padding='SAME', scope='conv3_5')
        max_pool_4 = slim.max_pool2d(conv3_5, [2, 2], [2, 2], padding='SAME', scope='pool4')

        flatten = slim.flatten(max_pool_4)
        fc1_2d = slim.fully_connected(slim.dropout(flatten, keep_prob), FLAGS.char_vec_len,
                                      activation_fn=tf.nn.relu, scope='fc1')
        # Regroup per-character vectors into sentences: [batch, sent_len_max, char_vec_len]
        fc1 = tf.reshape(fc1_2d, [-1, FLAGS.sent_len_max, FLAGS.char_vec_len])

    ## Bidirectional LSTM
    cell_fw = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(FLAGS.lstm_size, keep_prob) for _ in range(FLAGS.num_layers)])
    cell_bw = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(FLAGS.lstm_size, keep_prob) for _ in range(FLAGS.num_layers)])
    bi_lstm_outputs, final_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, fc1,
                                                                   sequence_length=seq_lens,
                                                                   dtype=tf.float32)
    # Concatenate forward and backward outputs: [..., 2 * lstm_size]
    lstm_outputs = tf.concat(bi_lstm_outputs, -1)
    x = tf.reshape(lstm_outputs, [-1, 2 * FLAGS.lstm_size])

    ## Fully connected layer
    softmax_w = tf.get_variable('w', [2 * FLAGS.lstm_size, charset_size],
                                initializer=tf.zeros_initializer())
    softmax_b = tf.get_variable('b', [charset_size],
                                initializer=tf.zeros_initializer())
    logits = tf.matmul(x, softmax_w) + softmax_b

    return logits

def build_graph(top_k, images, labels, seq_lens, mask, keep_prob,
                is_training, reuse_variables=False):
    with tf.variable_scope("encoder") as scope:
        if reuse_variables:
            tf.get_variable_scope().reuse_variables()
        logits = model(images, seq_lens, keep_prob, is_training)

        ## Viterbi decoding (CRF)
        if FLAGS.viterbi:
            logits_reshaped = tf.reshape(logits, [-1, FLAGS.sent_len_max, charset_size])
            labels_reshaped = tf.reshape(labels, [-1, FLAGS.sent_len_max])
            log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
                logits_reshaped, labels_reshaped, seq_lens)
            loss = tf.reduce_mean(-log_likelihood)
            labels_pred, _ = tf.contrib.crf.crf_decode(
                logits_reshaped, transition_params, seq_lens)
            logits_masked = tf.boolean_mask(tf.reshape(labels_pred, [-1]), mask)
            labels_masked = tf.boolean_mask(labels, mask)
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(logits_masked, tf.int64),
                                                       tf.cast(labels_masked, tf.int64)),
                                              tf.float32))
        else:
            logits_masked = tf.boolean_mask(logits, mask)
            labels_masked = tf.boolean_mask(labels, mask)
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits_masked, labels=labels_masked))
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits_masked, 1),
                                                       tf.cast(labels_masked, tf.int64)),
                                              tf.float32))

        ## Build the loss and train op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            loss = control_flow_ops.with_dependencies([updates], loss)

        global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0),
                                      trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = slim.learning.create_train_op(loss, optimizer, global_step=global_step)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        merged_summary_op = tf.summary.merge_all()

    return {'images': tf.boolean_mask(images, mask),
            'labels': labels_masked,
            'logits_pred_masked': logits_masked,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'is_training': is_training,
            'accuracy': accuracy,
            'merged_summary_op': merged_summary_op,
            'mask': mask,
            'seq_lens': seq_lens}
The code above contains three functions. The first, get_a_cell, builds a single dropout-wrapped LSTM cell and needs no further explanation. model is the complete network from the image input through the CNN and bidirectional LSTM to the logits; the CNN part is the same as in the HWDB handwritten-character-recognition blog post referenced earlier and is simple and straightforward. build_graph mainly constructs the loss and train_op used for training.
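To see how the tensor shapes evolve inside model, here is a small sketch that traces the spatial dimensions through the conv/pool stack. The 64x64 input size is a hypothetical assumption for illustration; the actual character-image size is set elsewhere in the project.

```python
import math

def trace_cnn_shapes(h, w, channels=1):
    """Trace (height, width, depth) through the conv/pool stack in model().
    'SAME' 3x3 convs keep the spatial size; each 2x2/stride-2 max-pool
    with 'SAME' padding halves it (rounding up)."""
    shapes = [(h, w, channels)]
    for depth in (64, 128, 256, 512):     # conv3_1..conv3_5 (conv3_4/5 share pool4)
        shapes.append((h, w, depth))      # after the conv(s) at this depth
        h, w = math.ceil(h / 2), math.ceil(w / 2)
        shapes.append((h, w, depth))      # after the 2x2 max-pool
    return shapes, h * w * 512            # final flattened length fed to fc1

# Assuming a 64x64 single-channel character image:
shapes, flat_len = trace_cnn_shapes(64, 64)
```

Under that assumption the four poolings shrink 64x64 down to 4x4x512, so slim.flatten produces a vector of length 8192 before fc1 projects it to FLAGS.char_vec_len.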
The build_graph function deserves a closer look. First, this snippet:
if reuse_variables:
    tf.get_variable_scope().reuse_variables()
If you want to evaluate accuracy on a cross-validation (CV) set every so often during training in order to monitor the training process, these two lines are essential: they make the second graph build reuse the model parameters currently being trained, instead of creating a fresh set of parameters. A detailed explanation is available at: https://blog.youkuaiyun.com/foreseerwang/article/details/79499553
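The reuse semantics can be illustrated with a small pure-Python analogy (this is not TensorFlow code, just a hypothetical sketch of how get_variable behaves): within a scope, get_variable either creates a new variable, or, once reuse is enabled, returns the one already created under the same name.

```python
class Scope:
    """Minimal pure-Python analogy of tf.variable_scope / tf.get_variable."""
    def __init__(self):
        self._vars = {}
        self.reuse = False

    def get_variable(self, name, initial_value=None):
        if self.reuse:
            # Reuse mode: the variable must already exist; return it.
            return self._vars[name]
        if name in self._vars:
            raise ValueError('Variable %s already exists' % name)
        self._vars[name] = [initial_value]  # boxed so it is shared by reference
        return self._vars[name]

scope = Scope()
w_train = scope.get_variable('w', 0.0)  # training graph creates 'w'
scope.reuse = True                      # like tf.get_variable_scope().reuse_variables()
w_cv = scope.get_variable('w')          # CV graph gets the very same object
assert w_cv is w_train
```

This mirrors why the CV graph sees the training graph's weights: both builds resolve the same variable names inside the "encoder" scope.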
The Viterbi (CRF) code:
logits_reshaped = tf.reshape(logits, [-1, FLAGS.sent_len_max, charset_size])
labels_reshaped = tf.reshape(labels, [-1, FLAGS.sent_len_max])
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    logits_reshaped, labels_reshaped, seq_lens)
loss = tf.reduce_mean(-log_likelihood)
labels_pred, _ = tf.contrib.crf.crf_decode(
    logits_reshaped, transition_params, seq_lens)
logits_masked = tf.boolean_mask(tf.reshape(labels_pred, [-1]), mask)
labels_masked = tf.boolean_mask(labels, mask)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(logits_masked, tf.int64),
                                           tf.cast(labels_masked, tf.int64)),
                                  tf.float32))
This is typical CRF/Viterbi code. Its core is crf_log_likelihood, which learns a transition matrix transition_params; crf_decode then uses that matrix to take context into account when producing the predicted labels. Besides crf_decode, TensorFlow also provides viterbi_decode. The two are functionally equivalent, but crf_decode runs inside the TensorFlow graph, whereas viterbi_decode runs outside of it. The FLAGS.viterbi switch enables or disables the CRF. In my experiments, turning it on does improve accuracy compared with the plain softmax model, but it also increases training time by a factor of 20 or more.
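For intuition, here is a minimal NumPy sketch of the dynamic program that crf_decode and viterbi_decode implement: given per-step unary scores (the logits) and a transition matrix, find the label sequence with the highest total score. The toy scores at the bottom are made up purely for illustration.

```python
import numpy as np

def viterbi_decode_np(score, transitions):
    """score: [seq_len, num_tags] unary scores; transitions: [num_tags, num_tags].
    Returns the highest-scoring tag sequence and its total score."""
    seq_len, num_tags = score.shape
    trellis = np.zeros_like(score)
    backptr = np.zeros((seq_len, num_tags), dtype=np.int64)
    trellis[0] = score[0]
    for t in range(1, seq_len):
        # candidates[i, j] = best score ending in tag i at t-1, then moving to tag j
        candidates = trellis[t - 1][:, None] + transitions
        backptr[t] = np.argmax(candidates, axis=0)
        trellis[t] = score[t] + np.max(candidates, axis=0)
    # Backtrack from the best final tag.
    path = [int(np.argmax(trellis[-1]))]
    for t in range(seq_len - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1], float(np.max(trellis[-1]))

# Toy example: 3 steps, 2 tags; the transition matrix strongly penalizes 0 -> 1,
# so the decoder prefers staying on tag 1 over the locally best first tag.
unary = np.array([[2.0, 1.0], [1.0, 2.0], [1.0, 2.0]])
trans = np.array([[0.0, -10.0], [0.0, 0.0]])
path, best = viterbi_decode_np(unary, trans)
```

Greedy argmax over unary alone would pick [0, 1, 1] and pay the -10 transition penalty; the Viterbi path [1, 1, 1] sacrifices the first step to keep the total score highest, which is exactly the context effect described above.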
The construction of the loss and train_op is routine code; there is not much to say about it.
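One detail worth making concrete is the boolean-mask pattern used in both branches above: padded positions are dropped before the accuracy is computed. A NumPy mirror of that computation, with made-up toy data, looks like this:

```python
import numpy as np

# Toy batch: 2 sentences padded to length 4; mask marks the real characters.
labels = np.array([5, 3, 8, 0,   7, 2, 0, 0])  # flattened [batch * sent_len_max]
preds  = np.array([5, 3, 1, 9,   7, 4, 6, 6])  # e.g. argmax over logits, or crf_decode output
mask   = np.array([1, 1, 1, 0,   1, 1, 0, 0], dtype=bool)

# Equivalent of tf.boolean_mask followed by
# tf.reduce_mean(tf.cast(tf.equal(...), tf.float32)):
labels_masked = labels[mask]
preds_masked = preds[mask]
accuracy = float(np.mean(preds_masked == labels_masked))
```

Without the mask, the padding positions (where both label and prediction are often trivially 0) would inflate the reported accuracy; here only the 5 real characters count.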
That's it. Unfortunately, this part of the code has not been verified on its own. The complete training and validation results will be presented in the next post.