语音识别系列6：语音识别 CTC 之 TensorFlow 前向计算
一、介绍
按照上一节的介绍，我们的模型已经训练好了。那么该如何使用训练好的模型呢？本节就来介绍这一点。模型训练完成后，会在模型保存目录下生成至少四个文件，分别为：
checkpoint、model.ckpt-*.data-00000-of-00001、model.ckpt-*.index、model.ckpt-*.meta
二、源码解析
2.1 首先，我们把这四个文件打包成一个 pb 文件，以便于使用。代码如下：
文件(test_inference.py):
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import shutil
import sys
import time
from os.path import join, isfile, abspath

import numpy as np
import tensorflow as tf
import yaml
from setproctitle import setproctitle
from tensorflow.python.framework import graph_util

from model_ctc import CTC
from basic_util import mkdir_join, mkdir
from basic_util import count_total_parameters
# Frame-splicing configuration used by general_frame: how many neighbouring
# frames are stacked on each side of a centre frame, and the stride between
# consecutive centre frames (i.e. the frame-rate subsampling factor).
left_context = 10
right_context = 10
skip = 4


def general_frame(feature, seq_len):
    """Splice neighbouring frames around every ``skip``-th frame.

    For each centre frame ``n`` in ``range(0, seq_len, skip)`` the rows
    ``n - left_context .. n + right_context`` of ``feature`` are concatenated
    into one vector. Indices outside ``[0, seq_len - 1]`` are clamped to the
    first/last valid frame (edge replication).

    Args:
        feature: 2-D array of shape (frames, dim); only the first ``seq_len``
            rows are read.
        seq_len: number of valid frames in ``feature``.

    Returns:
        (new_feature, new_seq_len): a float32 array of shape
        (ceil(seq_len / skip), (left_context + right_context + 1) * dim) and
        the number of spliced frames as ``np.int64``. The explicit int64
        matters because the caller wraps this in ``tf.py_func`` with a
        ``tf.int64`` output.
    """
    mfcc_len = feature.shape[1]
    feat_num = left_context + right_context + 1
    centers = np.arange(0, seq_len, skip)
    offsets = np.arange(-left_context, right_context + 1)
    # (num_centers, feat_num) matrix of source-row indices, clamped so frames
    # before the start / past the end replicate the boundary frame.
    frame_idx = np.clip(centers[:, None] + offsets[None, :], 0, max(seq_len - 1, 0))
    # Explicit width (not -1) so an empty utterance still yields a 2-D
    # (0, feat_num * mfcc_len) array instead of a rank-1 empty one.
    new_feature = feature[frame_idx].reshape(len(centers), feat_num * mfcc_len)
    new_feature = new_feature.astype(np.float32, copy=False)
    # len(range(0, seq_len, skip)) == ceil(seq_len / skip), computed exactly in
    # integers (math.ceil returns a float on Python 2).
    new_seq_len = np.int64(len(centers))
    return new_feature, new_seq_len
def parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` into model inputs.

    The raw float32 feature bytes are decoded into a (frames, 40) matrix. The
    raw int64 label bytes are decoded into a 1-D id vector. The frames are
    then spliced/subsampled by ``general_frame`` through ``tf.py_func``.

    Args:
        example_proto: scalar string tensor holding one serialized Example.

    Returns:
        (feature, label, seq_len): float32 spliced feature matrix of shape
        (frames, (left_context + right_context + 1) * 40), 1-D int64 label
        ids, and the int32 number of spliced frames.
    """
    features = {'feature': tf.VarLenFeature(tf.string),
                'label': tf.VarLenFeature(tf.string),
                'seq_len': tf.FixedLenFeature([], tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)
    # Only entry [0] of each bytes list is decoded; presumably each example
    # stores a single raw buffer per key -- confirm against the writer.
    feature = tf.sparse_tensor_to_dense(parsed_features['feature'], default_value=b'0.0')
    feature = tf.decode_raw(feature[0], tf.float32)
    feature = tf.reshape(feature, [-1, 40])
    label = tf.sparse_tensor_to_dense(parsed_features['label'], default_value=b'0')
    label = tf.decode_raw(label[0], tf.int64)
    seq_len = parsed_features['seq_len']
    feature, seq_len = tf.py_func(general_frame, [feature, seq_len], [tf.float32, tf.int64])
    # tf.py_func outputs carry no static shape; restore what general_frame
    # produces so padded_batch and the model see known rank/width.
    feature.set_shape([None, (left_context + right_context + 1) * 40])
    seq_len.set_shape([])
    seq_len = tf.cast(seq_len, tf.int32)
    return feature, label, seq_len
def dense_to_sparse(dense, pad_value=-1):
    """Convert a padded dense label batch into ``tf.SparseTensor`` components.

    Each row is truncated at its first ``pad_value``, and every element
    before it is kept. In this script the labels are padded with -1 by
    ``padded_batch``, hence the default.

    Args:
        dense: 2-D integer array of shape (batch, max_label_len).
        pad_value: sentinel marking the end of a row's real labels.

    Returns:
        (indices, values, shape):
            indices: int64 array of shape (N, 2). It is always 2-D, even when
                N == 0, as ``tf.SparseTensor`` requires.
            values: int32 array of shape (N,).
            shape: int64 dense shape of ``dense``.
    """
    dense = np.asarray(dense)
    # A position is kept while no padding has appeared earlier in its row;
    # cumprod turns the first False into False for the rest of the row.
    valid = np.cumprod(dense != pad_value, axis=1).astype(bool)
    indices = np.argwhere(valid).astype(np.int64)  # row-major, shape (N, 2)
    values = dense[valid].astype(np.int32)
    shape = np.asarray(dense.shape, dtype=np.int64)
    return indices, values, shape
def do_inference(model, params):
# Tell TensorFlow that the model will be built into the default graph
with tf.Graph().as_default(), tf.device('/cpu:0'):
# NOTE: /cpu:0 is prepared for evaluation
with tf.variable_scope(tf.get_variable_scope()):
model.create_placeholders()
logits = model._build(model.inputs_pl_list[0], model.inputs_seq_len_pl_list[0], model.keep_prob_pl_list[0], is_training=False)
logits=tf.nn.softmax(logits, name="logits_softmax")
decode = model.decoder(logits, model.inputs_seq_len_pl_list[0], beam_width=params['beam_width'])
ler = model.compute_ler(decode, model.labels_pl_list[0])
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=None)
train_dataset = tf.data.TFRecordDataset(params['train_data_file'])
train_dataset = train_dataset.map(parse_function)
train_dataset = train_dataset.padded_batch(params['batch_size'], padded_shapes=([None, None],
[None], []), padding_values=(0.0, tf.cast(-1, tf.int64), tf.cast(0, tf.int32)))
train_dataset = train_dataset.repeat(1)
iterator = train_dataset.make_initializable_iterator()
batch_feat, batch_label, batch_seq_len = iterator.get_next()
s_indices, s_value, s_shape = tf.py_func(dense_to_sparse, [batch_label], [tf.int64, tf.int32, tf.int64])
batch_label = tf.SparseTensor(s_indices, s_value, s_shape)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=False)) as sess:
sess.run(init_op)
start_time_train = time.time()
step_id = 0
last_id = 0
ckpt = tf.train.latest_checkpoint(model.save_path)
print("===========================")
print("ckpt:", ckpt)
if ckpt != None:
saver.restore(sess, ckpt)
ind = ckpt.rfind("-")
last_id = int(ckpt[ind + 1:])
print("++++++++++++++++++++++++++++")
print("last_id: ", last_id)
print("===========================")
start_time_test = time.time()
sess.run(iterator.initializer)