问题描述
自定义实现一个类似BatchNormal的功能层,在单GPU情况下,动量更新变量running mean均值与running var方差可以正常随训练迭代累积更新,但换上多GPU环境(使用nn.DataParallel包装模型)时模型性能下降非常离谱,然后查看了一下那个BatchNormal层的running mean和running var,发现每次迭代都从初始的0和1开始,原本应该是累积更新的,而且最后在保存state_dict()时,保存的也是0和1
解决方案
找了好久都没找到有类似问题的,最后在PyTorch社区发现了
将=替换为.copy_
self.running_mean.copy_(...)
# instead of
self.running_mean = (...)