Speech Recognition Series 6: TensorFlow Forward Computation for CTC Speech Recognition

1. Introduction

As described in the previous post, the model has now been trained; this post covers how to actually use it. After training finishes, the model save directory contains at least four files:

checkpoint, model.ckpt-*.data-00000-of-00001, model.ckpt-*.index, model.ckpt-*.meta
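The .data and .index files hold the variable values, the .meta file holds the graph definition, and checkpoint is a small text file recording the latest prefix. If you want to sanity-check what a checkpoint contains before packaging it, a quick sketch (the prefix model.ckpt-1000 is hypothetical; substitute the step number from your own directory):

import tensorflow as tf

# List every variable name and shape stored in the checkpoint.
# "model.ckpt-1000" is a placeholder prefix, not a file from this tutorial.
reader = tf.train.NewCheckpointReader("model.ckpt-1000")
for name, shape in reader.get_variable_to_shape_map().items():
  print(name, shape)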

 

2. Source Code Walkthrough

2.1 First, we pack these four files into a single pb file so the model is easy to deploy. The code is as follows:

File (test_inference.py):

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os.path import join, isfile, abspath
import sys 
import time
from setproctitle import setproctitle
import shutil
import yaml
import os

from model_ctc import CTC 
from basic_util import mkdir_join, mkdir
from basic_util import count_total_parameters
import numpy as np
import math
import tensorflow as tf

from tensorflow.python.framework import graph_util

# Frame-splicing context sizes and frame-subsampling (skip) factor.
left_context = 10
right_context = 10
skip = 4

def general_frame(feature, seq_len):
  max_len, mfcc_len = feature.shape
  feat_num = left_context + right_context + 1
  # For every skip-th frame, stack its left_context + 1 + right_context
  # neighbouring frames (indices clamped to [0, seq_len - 1]) into one wide
  # spliced frame.
  frame_list = [np.concatenate([feature[min(max(m, 0), seq_len - 1)]
                                for m in range(n - left_context, n + right_context + 1)])
                for n in range(0, seq_len, skip)]
  # Subsampling shortens the sequence; use an explicit int64 so the result
  # matches the dtype declared for tf.py_func below.
  new_seq_len = np.int64(math.ceil(seq_len / skip))
  new_feature = np.asarray(frame_list).astype(np.float32)
  return new_feature, new_seq_len
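# Example of the splicing above (hypothetical utterance): with left_context =
# right_context = 10 and skip = 4, an utterance of 100 frames of 40-dim
# features becomes ceil(100 / 4) = 25 spliced frames, each of dimension
# (10 + 1 + 10) * 40 = 840.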

def parse_function(example_proto):
  features = {'feature': tf.VarLenFeature(tf.string),
              'label'  : tf.VarLenFeature(tf.string),
              'seq_len': tf.FixedLenFeature([], tf.int64)}
  parsed_features = tf.parse_single_example(example_proto, features)
  # The acoustic features are stored as raw bytes: decode to float32 and
  # reshape to [num_frames, 40] (40-dimensional features per frame).
  feature = tf.sparse_tensor_to_dense(parsed_features['feature'], default_value=b'0.0')
  feature = tf.decode_raw(feature[0], tf.float32)
  feature = tf.reshape(feature, [-1, 40])
  # Labels are stored the same way, as raw int64 bytes.
  label = tf.sparse_tensor_to_dense(parsed_features['label'], default_value=b'0')
  label = tf.decode_raw(label[0], tf.int64)
  seq_len = parsed_features['seq_len']
  # Frame splicing / subsampling runs as a py_func on the decoded features.
  feature, seq_len = tf.py_func(general_frame, [feature, seq_len], [tf.float32, tf.int64])
  seq_len = tf.cast(seq_len, tf.int32)
  return feature, label, seq_len


def dense_to_sparse(dense):
  indices = []
  values = []
  for n, seq in enumerate(dense):
    # Append a sentinel -1 so argmin always finds the start of the -1 padding,
    # then cut the row there to keep only the valid labels.
    seq = np.append(seq, -1)
    seq = seq[:np.argmin(seq)]
    indices.extend(zip([n] * len(seq), range(len(seq))))
    values.extend(seq)
  indices = np.asarray(indices, dtype=np.int64)
  values = np.asarray(values, dtype=np.int32)
  shape = np.asarray(dense.shape, dtype=np.int64)
  return indices, values, shape
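# Worked example (hypothetical batch): dense = [[3, 7, -1], [5, -1, -1]] gives
# indices [[0, 0], [0, 1], [1, 0]], values [3, 7, 5], shape [2, 3], which is
# exactly the triple tf.SparseTensor expects for the CTC labels.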

def do_inference(model, params):
  # Build the inference graph into the default graph.
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # NOTE: /cpu:0 is prepared for evaluation
    with tf.variable_scope(tf.get_variable_scope()):
      model.create_placeholders()
      logits = model._build(model.inputs_pl_list[0], model.inputs_seq_len_pl_list[0],
                            model.keep_prob_pl_list[0], is_training=False)
      # Name the softmax node so it can be referenced when freezing the graph.
      logits = tf.nn.softmax(logits, name="logits_softmax")
      decode = model.decoder(logits, model.inputs_seq_len_pl_list[0], beam_width=params['beam_width'])
      ler = model.compute_ler(decode, model.labels_pl_list[0])

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=None)

    # Input pipeline from TFRecords: pad features with 0.0 and labels with -1
    # so dense_to_sparse can recover the true label lengths.
    train_dataset = tf.data.TFRecordDataset(params['train_data_file'])
    train_dataset = train_dataset.map(parse_function)
    train_dataset = train_dataset.padded_batch(params['batch_size'],
        padded_shapes=([None, None], [None], []),
        padding_values=(0.0, tf.cast(-1, tf.int64), tf.cast(0, tf.int32)))
    train_dataset = train_dataset.repeat(1)

    iterator = train_dataset.make_initializable_iterator()

    batch_feat, batch_label, batch_seq_len = iterator.get_next()
    s_indices, s_value, s_shape = tf.py_func(dense_to_sparse, [batch_label], [tf.int64, tf.int32, tf.int64])
    batch_label = tf.SparseTensor(s_indices, s_value, s_shape)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
      sess.run(init_op)
      start_time_train = time.time()
      step_id = 0
      last_id = 0
      # Restore the latest checkpoint and recover its global step from the filename.
      ckpt = tf.train.latest_checkpoint(model.save_path)
      print("===========================")
      print("ckpt:", ckpt)
      if ckpt is not None:
        saver.restore(sess, ckpt)
        ind = ckpt.rfind("-")
        last_id = int(ckpt[ind + 1:])
        print("++++++++++++++++++++++++++++")
        print("last_id: ", last_id)
      print("===========================")
      start_time_test = time.time()
      sess.run(iterator.initializer)
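
After the checkpoint has been restored in the session, the packaging into a single pb file described in 2.1 is done with the graph_util module imported at the top of the script. A minimal sketch of that freezing step, assuming it runs inside the same tf.Session as above and that the output node is the softmax named logits_softmax (the output path frozen_model.pb is just a placeholder):

# Freeze the restored variables into constants and write a single .pb file.
output_node_names = ["logits_softmax"]
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names)
with tf.gfile.GFile("frozen_model.pb", "wb") as f:  # placeholder output path
  f.write(frozen_graph_def.SerializeToString())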