前提知识
BN层包括mean var gamma beta四个参数,。对于图像来说(4,3,2,2),一组特征图,一个通道的特征图对应一组参数,即四个参数均为维度为通道数的一维向量,图中gamma、beta参数维度均为[1,3]
其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;
而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,不用再计算均值方差,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。
class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True)
nn.BatchNorm2d(self, num_features, eps=1e-5, momentum=0.1, affine=Tru