问题背景
最近使用DNN模型来做排序,发现近几天的NDCG离线指标下跌得很厉害。于是下载模型自己在本地评测了一下,预测结果都是NaN,于是把各层的模型参数以及各层的输出都打印出来,发现BatchNormalize中的moving_variance(方差)的某一维是NaN,最后一查果然是这一维特征异常了。为了把事情弄清楚,写这个blog记录一下。
BatchNormalize(BN)基础知识
BN的提出是为了解决神经网络中Internal Covariate Shift的问题,Internal Covariate Shift简单地说就是各层网络的输出会产生分布的变化,而分布变化使神经网络比较难收敛。BN的思想也很简单,就是通过把输入各个特征的分布转化成均值为0,方差为1的正态分布上去。训练过程是在mini-batch中进行操作的,具体也就是求出均值、标准差,通过下边式子对输入特征进行转化。
为了保留原来的特征分布信息,加入了可学习的参数:
可以看出,当伽马等于标准差,贝塔等于均值时就完全还原了输入。
但是在模型预测时mini-batch可能只有一个实例,所以在实现中的做法是保留每一个batch计算的均值和方差,通过平滑的方式计算预测使用的均值和方差:
问题复现
使用了如下代码进行了问题的复现:
import tensorflow as tf
import numpy as np
bn_input = tf.cast([[1,