Tensorflow API中LSTM的参数提取及手工复现模型推理(不使用API复现)
对于深度学习任务,使用Tensorflow的API进行训练,推理是目前主流的实现方式。但是训练模型和运用模型进行推理,往往处于不同的工作场景。比如训练模型使用服务器的GPU集群进行加速训练,而通常希望运用模型的场景是在嵌入式设备上,ARM、FPGA、或者像我的需求一样,需要设计ASIC来进行加速推理(反正应用场景就是没有GPU也多半装不上TensorFlow )。
附:本文不提供LSTM模型训练的教程,本文适用于需要从Tensorflow 模型中提取LSTM参数,并且以及需要手动复现的读者。代码由Python编写。提取的模型适用于TensorFlow 1.14及之前的版本。介绍的两种LSTM API手动复现时偏置计算时有细微区别,请读者注意。
LSTM的结构
由于这些设备上往往只能使用训练好的参数进行计算,但是成本限制或者产品需求不能使用API,因此必须要搞清楚网络结构,并且手动复现这个计算过程,以及提取出模型保存好的参数来进行推理。首先介绍LSTM的基本结构:
[结构图引自博客 ]https://blog.youkuaiyun.com/kami0116/article/details/94749564.
其中的c_prev和h_prev是前一时刻的状态,c,h是当前时刻的输出。通常LSTM中会有多个隐藏层,隐藏层可以理解成多个LSTM的串接,前一个LSTM的输出作为下一个LSTM的输入。但是在实际应用中,如FPGA或者ASIC实现,从硬件思维来考虑,由于LSTM对于时间连续性的要求,使得多个LSTM和单个LSTM在算一路数据而言,其耗时相当。c_prev和h_prev是一组向量,而第一个LSTM单元的c_prev和h_prev被默认为全为0的向量值。
Wf,Wi,Wj,Wo则是LSTM的权重信息,j 在文献中通常会用C来表示,为了区分以及方便对照着这个结构图写代码,此处用 j 表示。
公式中括号的[x,hprev]表示向量的拼接,[x,hprev] · Wf ,代表矩阵乘法,硬件中的操作是MAC(乘累加),*则是矩阵中相同位置的数字相乘没有累加这一操作。σ是sigmoid激活函数,tanh也是激活函数。
LSTM结构介绍完毕,代码分析。
TensorFlow的API中,有两个构造LSTM的函数:
1:
tf.contrib.rnn.BasicLSTMCell()
2:
tf.contrib.cudnn_rnn.