最近经常出现一个错误,在模型训练的时候loss:inf,如果出现的不多的话还是可以接受的,但是一旦这个大量出现,模型就不能训练了,损失也很难收敛,所以今天我终于把这个问题解决了,写下来表示分享。
经过分析,是输入长度和标签长度之间的问题,网上说要求输入长度要大于标签长度,我看了一下我的输入长度13,标签长度10,符合要求,但是依然出现错误,我换了一个模型之后输入长度14,标签长度10,问题消失,得出结论,输入长度要高于标签长度一部分,至于高出多少,应该考虑识别的字符串中重复并且相邻的字符数,简单来说就是尽量的多一些吧,目前没有分析增加输入长度对性能的影响,至少肉眼感觉不出来,但是影响性能是肯定的。
下面说怎么增加输入长度,我们知道用CTCloss的时候需要有四个输入,分别如下:
后两个参数就是输入长度和标签长度,标签长度肯定是没法改的,这个需求是固定的,所以只能改input_length,后来我发现input_length是和网络结构相关的,如图:
可以看到,基础网络传入CTC中的尺寸是(15,37),这个15和我们的input_length就有关系了ÿ