采用示例https://github.com/apache/incubator-mxnet/tree/master/example/speech_recognition
训练命令为
python main.py --configfile deepspeech.cfg
在文件stt_layer_warpctc.py的 net = mx.sym.Concat(*fc_seq, dim=0) 这一行的下面添加一行,
mx.viz.print_summary(net, shape={"data": (1, 1600, 161)})
即可以打印出网络结构,不过打印结果有点恐怖,居然有292020行,也就是29万多行,贴不出来,贴一下直接网页就卡死了
我还是想办法,把参数改小一点,比如在deepspeech.cfg中,把参数改为:
buckets = [20, 40]
num_rnn_layer = 3
num_hidden_rnn_list = [1760, 1760, 1760]
上面的语句也相应的改为
mx.viz.print_summary(net, shape={"data": (1, 40, 161)})
打印出来的网络结构,也挺大的,1150行,还是粘不下,我上传到了百度网盘
链接:https://pan.baidu.com/s/1uJ_0_vmgVBTW7fSkYIM_uw
提取码:7rvv
___________________________________________
上面这一行也可以采用 mx.viz.plot_network(net, shape={"data": (1, 40, 161)}).view(),会把网络结构保存为一个pdf文件,我把该pdf文件上传到了 https://download.youkuaiyun.com/download/zhqh100/12116047,可以参考一下
————————————————————————
上面的方法有点弱,但是上传到优快云的资源自己居然无法删除(感觉好恶心),,只好先留着
另一种打印方法是在文件stt_layer_warpctc.py的 “return net” 前添加一行
mx.viz.plot_network(net, shape={"data": (1, 40, 161), "label":(1, 183)}).view()
也会把网络结构保存为一个pdf文件,我也上传到了
链接:https://pan.baidu.com/s/1uJ_0_vmgVBTW7fSkYIM_uw
提取码:7rvv
————————————————————————
上面代码中的参数值,我是在文件stt_io_bucketingiter.py中,
self.provide_label = [('label', (self.batch_size, self.maxLabelLength))]
这一行的下面添加了打印
print(self.provide_data)
print(self.provide_label)
获取到的
_____________________________________________
本文完
return mx.io.DataBatch(data_all, label_all, pad=0,
bucket_key=self.buckets[i],
provide_data=provide_data,
provide_label=self.provide_label)