自定义BatchNormal中running mean与running var在nn.DataParallel下不更新的问题

问题描述

自定义实现一个类似BatchNormal的功能层,在单GPU情况下,动量更新变量running mean均值与running var方差可以正常随训练迭代累积更新,但换上多GPU环境(使用nn.DataParallel包装模型)时模型性能下降非常离谱,然后查看了一下那个BatchNormal层的running mean和running var,发现每次迭代都从初始的0和1开始,原本应该是累积更新的,而且最后在保存state_dict()时,保存的也是0和1

解决方案

找了好久都没找到有类似问题的,最后在PyTorch社区发现了

link

将=替换为.copy_

self.running_mean.copy_(...)
# instead of
self.running_mean = (...)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值