Pytorch中的BatchNorm

本文详细解析了BatchNorm的归一化流程,并重点剖析了momentum的作用,通过实例演示展示了其在更新均值和方差时的计算方式。此外,文中还提及了momentum的默认设置和其在实际应用中的意义。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

首先放一张四种归一化的对比

请添加图片描述

BatchNorm

最基础的这个。原论文给出的流程:

请添加图片描述
就是先对送进来的数据计算均值和方差,然后归一化,最后再拟合一下。
再对应一下pytorch中的公式。没问题,可以完美对应。

y = x − E [ x ] Var ⁡ [ x ] + ϵ ∗ γ + β y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta y=Var[x]+ϵ xE[x]γ+β

先看看Pytorch怎么使用的

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

x = torch.randn([1,3,255,255])
bn = torch.nn.BatchNorm2d(3)
# bn.eval()
y = bn(x)

print(bn.running_mean)  # 均值
print(bn.running_var)   # 方差

tensor([-0.0029,  0.0019,  0.0008])
tensor([0.9955, 0.9959, 1.0055])

有以下的地方需要注意:
(1)bn.eval()会使bn.training=False,running_mean和running_var不再更新;train()则相反。
(2)bn.running_mean为均值,bn.running_var是方差。

再看看bn都包含什么
请添加图片描述
可以发现:
(1)running_mean和running_var是初始值分别是0和1,而维度是输入channel的值。
(2)weight和bias默认为1和0
(3)training=True(你可能会说track_running_stats=True,但我至今没发现这两项的实际作用有啥区别),running_mean和running_var默认随数据更新。

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

也就是training 和 track_running_stats同时为True的时候才更新状态。

(4)还有一项难以解释的,momentum默认为0.1

momentum

这一项看了好半天,去看原论文Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,里面没提到(给出的流程在上面。)

而pytorch官方解释也是不太清楚
请添加图片描述
明显,momentum用于更新均值和方差。但是这个x的上标是怎么回事?对拟合后的数据进行计算均值和方差?

咱还是想办法试一下吧。用两个输入x1~N(0,1) 和 x2~N(1,4)

# x1~N(0,1)
x1 = torch.randn([1,3,255,255])
bn = torch.nn.BatchNorm2d(3)
y1 = bn(x1)

print(bn.running_mean)
print(bn.running_var)
print('data:',torch.mean(y1),torch.var(y1))

# x2~N(1,4)
x2 = torch.randn([1,3,255,255]) * 2 +1
y2 = bn(x2)

print(bn.running_mean)
print(bn.running_var)
print('data:',torch.mean(y2),torch.var(y2))


tensor([ 0.0001,  0.0001, -0.0006])
tensor([1.0004, 1.0002, 0.9998])
data: tensor(0., grad_fn=<MeanBackward0>) tensor(1.0000, grad_fn=<VarBackward0>)
tensor([0.1006, 0.1004, 0.1006])
tensor([1.2988, 1.3021, 1.2999])
data: tensor(2.9724e-09, grad_fn=<MeanBackward0>) tensor(1.0000, grad_fn=<VarBackward0>)

此时没有反向传播,所以weight和bias一直是1和0。首先来看,bn的输出y1,y2~N(0,1)达到了预期效果。在执行y1 = bn(x1)后的running_mean和bn.running_var也很好解释。

问题在于在执行y2 = bn(x2)后,running_mean和bn.running_var分别是0.1和1.3,这是怎么来的?

看Pytorch的解释的意思,难道是计算E((1-m)*x1+m*x2)和D((1-m)*x1+m*x2),结果怎么算怎么不对。

实际上应该是这样
running_mean = (1-m)*E(x1) + m*E(x2) = (1-0.1)*0+0.1*1 = 0.1
running_var = (1-m)*D(x1) + m*D(x2) = (1-0.1)*1+0.1*4 = 1.3

参考文献

  1. Group Normalization
  2. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
<think>嗯,用户想了解如何在PyTorch中固定BatchNorm2d层,特别是在推理阶段禁用统计更新或者调试模型。那我得先回忆一下PyTorchBatchNorm2d的工作原理。根据之前的引用资料,BatchNorm2d在训练阶段会计算每个batch的均值和方差,并更新running_mean和running_var。而在测试阶段,应该使用这些统计量而不是当前batch的数据。首先,用户可能是在训练过程中或者加载预训练模型后,在推理时发现结果不一致,这可能是因为BatchNorm层没有正确固定。这时候,应该将模型设置为eval模式,使用model.eval(),这样BatchNorm层就不会更新running_mean和running_var了。引用4中提到,如果在测试阶段不设置eval模式,即使不更新参数,BN的统计量还是会变化,导致结果不同。另外,用户可能还想知道除了eval模式之外的其他方法,比如设置track_running_stats=False,但根据引用4,这个参数控制是否跟踪运行时的统计量。如果设为False,可能就不会使用保存的running_mean和running_var,而是用当前batch的数据,这可能在调试时有用。不过一般情况下,在推理时应该保持track_running_stats=True,并且用eval模式。还有,用户可能关心如何冻结BatchNorm层的参数,比如weight和bias。虽然这些参数在训练时是通过反向传播更新的,但在某些情况下可能需要固定它们。可以通过设置requires_grad=False来实现,但需要注意这不会影响running_mean和running_var的更新,除非同时使用eval模式。 此外,引用3提到手动计算BatchNorm的结果可能和代码结果不一致,可能是因为参数设置的问题。所以在固定BatchNorm时,确保所有相关设置正确,比如momentum参数,但默认的0.1通常不需要更改。 总结下来,关键步骤是:1)使用model.eval()切换到测试模式;2)确保track_running_stats=True;3)必要时冻结可学习参数。同时要注意,如果模型中有多个子模型,需要分别设置它们的eval模式,如引用4中提到的联合训练的情况。</think>在PyTorch中固定`BatchNorm2d`层的统计更新,主要通过控制模型状态和参数设置实现。以下是具体操作方法和原理分析: --- ### **步骤1:切换到推理模式** 调用`model.eval()`将模型设置为**评估模式**,这会自动冻结所有`BatchNorm2d`层的`running_mean`和`running_var`更新: ```python model = MyModel() # 包含BatchNorm2d层的模型 model.eval() # 固定BatchNorm的统计量 ``` - **原理**:在评估模式下,`BatchNorm2d`的`momentum`参数失效,停止从当前batch计算新统计量[^4]。 - **验证方法**:检查`BatchNorm2d.training`属性是否为`False`: ```python print(bn_layer.training) # 输出应为False ``` --- ### **步骤2:控制参数更新(可选)** 若需完全冻结`BatchNorm2d`的可学习参数(`weight`和`bias`),需手动设置`requires_grad=False`: ```python for module in model.modules(): if isinstance(module, nn.BatchNorm2d): module.weight.requires_grad = False module.bias.requires_grad = False ``` - **注意**:此操作仅冻结参数梯度,不影响`running_mean/running_var`的更新,需配合`model.eval()`使用[^3]。 --- ### **步骤3:禁用跟踪统计量(高级场景)** 通过设置`track_running_stats=False`,强制`BatchNorm2d`始终使用当前batch的统计量(类似`InstanceNorm`): ```python bn_layer = nn.BatchNorm2d(64, track_running_stats=False) ``` - **适用场景**:调试或特殊需求场景(如动态适应输入分布)[^4]。 --- ### **关键原理总结** 1. **训练模式**:`BatchNorm2d`通过动量公式更新统计量: $$running\_mean = (1 - momentum) \times running\_mean + momentum \times batch\_mean$$[^2] 2. **推理模式**:直接使用预存的`running_mean`和`running_var`归一化输入: $$output = \frac{input - running\_mean}{\sqrt{running\_var + \epsilon}} \times weight + bias$$[^1] ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值