BatchNorm原理以及PyTorch实现

BatchNorm算法

在这里插入图片描述
简单来说,BatchNorm对输入的特征按通道计算期望和方差(第1和第2个公式),并标准化(第3个公式:减去均值,再除以标准差 $\sqrt{\sigma^2+\epsilon}$,变为均值0、方差1)。但这会降低网络的表达能力,因此,BN在标准化后还要进行缩放平移,也就是可学习的参数 $\gamma$ 和 $\beta$,也对应每个通道。

BatchNorm有效的根本原理尚无定论:可能是降低了Internal Covariate Shift,也可能是使optimization landscape变得更平滑。

优点

  • 提高训练稳定性,可使用更大的learning rate、降低初始化参数的要求并可以构建更深更宽的网络;
  • 加速网络收敛。

缺点

  • 增加计算量和内存开销,降低推理速度;
  • 增加训练和推理时的差异;
  • 打破了minibatch之间的独立性;
  • 小batch效果差。

BatchNorm 在训练时,仅用当前Batch的均值和方差,而测试推理时,使用EMA计算的均值和方差。

PyTorch Code

以nn.BatchNorm2d为例。其继承关系为:Module $\to$ _NormBase $\to$ _BatchNorm $\to$ BatchNorm2d。其中 Module 是所有PyTorch网络模块的父类。

_NormBase

_NormBase主要是注册和初始化参数

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    def __init__(
        self,
        num_features: int, # 特征通道数
        eps: float = 1e-5,	# 防止分母为0
        momentum: float = 0.1, # 
        affine: bool = True, # 标准化后是否进行缩放,是否使用\gamma 和 \beta
        track_running_stats: bool = True, # 使用均值方差进行标准化
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) # 注册\gamma,后续初始化为1
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) # 注册\beta,后续初始化为0
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs)) # 注册期望,后续初始化为0
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs)) # 注册方差,后续初始化为1
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

	# 参数初始化,\gamma 为 1,\beta 为 0.
    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

_BatchNorm

调用nn.functional.batch_norm 对每个通道进行计算:

class _BatchNorm(_NormBase):
    """Shared implementation for the BatchNorm family.

    Delegates the per-channel normalization to ``F.batch_norm``; this class
    only decides which statistics (mini-batch vs. running buffers) to use and
    how the running buffers are updated.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,	# EMA factor for running stats; None selects cumulative averaging (see the momentum section below)
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        # Validate input rank first (implemented by subclasses, e.g. 4D for BatchNorm2d).
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        # Only update the batch counter when training AND tracking stats.
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                # NOTE: deliberately out-of-place (`x + 1`, not `+=`) — rebinds the attribute.
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    # Factor 1/t makes the running stats a plain mean over all batches seen.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )

BatchNorm2d

特化了输入检查

class BatchNorm2d(_BatchNorm):
    """BatchNorm over 4D inputs (N, C, H, W); only the rank check is specialized."""

    def _check_input_dim(self, input):
        # Guard clause: a 4D (N, C, H, W) tensor is the only accepted shape.
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError("expected 4D input (got {}D input)".format(ndim))

关于momentum参数

按照PyTorch注释,momentum参与running_mean和running_var的计算。置为None时,简单计算平均(累积移动平均)。默认值为0.1。

_BatchNorm中,赋值给了

exponential_average_factor = self.momentum

当其不为None时,也就是指数移动平均(Exponential Moving Average, EMA)。其计算公式为:
$$\bar{x}_t = \beta \mu_t + (1-\beta)\bar{x}_{t-1}$$
其中,$\mu_t$ 是当前Batch的均值或方差,$\beta$ 为 exponential_average_factor。展开:
$$\begin{aligned} \bar{x}_t &= \beta \mu_t + (1-\beta)\bigl(\beta \mu_{t-1} + (1-\beta)(\beta \mu_{t-2} + (1-\beta)\bar{x}_{t-3})\bigr)\\ &= \beta \mu_t + (1-\beta)\beta \mu_{t-1} + (1-\beta)^2\beta \mu_{t-2} + \dots + (1-\beta)^t\beta \mu_0 \end{aligned}$$
从公式可以看出,越靠近当前的数据占的比重越大,比重按指数衰减。其值约等于最近 $\frac{1}{\beta}$ 次的均值。

### WGAN 基本原理

Wasserstein GAN (WGAN) 是一种改进版的生成对抗网络(GAN),旨在解决原始 GAN 训练过程中遇到的一些问题,比如模式崩溃和训练不稳定等问题[^1]。传统的 GAN 使用的是 Jensen-Shannon 散度来衡量真实数据分布 \( P_r \) 和生成的数据分布 \( P_g \),这可能导致梯度消失现象,在某些情况下使得模型难以收敛。

为了克服这些问题,WGAN 引入了 Wasserstein 距离作为损失函数的基础。该距离也被称为 Earth Mover's Distance (EMD),它提供了更平滑的距离度量方式,并且对于概率分布之间的差异更加敏感。具体来说,给定两个分布 \( P_r \) 和 \( P_g \),它们之间的 Wasserstein 距离定义为:
\[ W(P_r, P_g)=\inf_{\gamma \in \Pi\left(P_{r}, P_{g}\right)} \mathbb{E}_{(x, y) \sim \gamma}[d(x, y)] \]
其中 \( d(\cdot,\cdot) \) 表示样本间的某种成本函数,通常取欧几里得范数;\( \Pi(P_r,P_g) \) 表示所有边缘分布分别等于 \( P_r \) 和 \( P_g \) 的联合分布的集合[^3]。

然而直接计算上述表达式非常困难,因此通过 Kantorovich-Rubinstein 对偶理论可以将其转换成更容易处理的形式:
\[ W(P_r, P_g)=\sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f(x)]-\mathbb{E}_{y \sim P_g}[f(y)] \]
这里 \( \|f\|_L \leq 1 \) 表示 Lipschitz 常数不超过 1 的连续函数。这个新的形式允许我们利用神经网络去近似函数 f ,从而简化优化过程[^4]。

### PyTorch 实现教程

下面是一个简单的例子,展示如何使用 PyTorch 实现 WGAN(含梯度惩罚,即 WGAN-GP):

```python
import numpy as np
import torch
from torch import nn, optim, autograd


class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=784):  # MNIST 图像大小为 28*28=784 像素点
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, output_dim),
            nn.Tanh(),  # 输出范围[-1,+1], 需要预处理输入图像至相同区间
        )

    def forward(self, z):
        return self.model(z)


class Critic(nn.Module):  # 注意:在WGAN中称为Critic而不是Discriminator
    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        dim = int(torch.prod(torch.tensor(img_shape)))  # 将多维张量展平后的维度
        self.model = nn.Sequential(
            nn.Linear(dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, imgs):
        imgs_flat = imgs.view(imgs.size(0), -1)
        validity = self.model(imgs_flat)
        return validity


def compute_gradient_penalty(critic, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    alpha = Tensor(np.random.random((real_samples.size(0), 1)))
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    critic_interpolates = critic(interpolates)
    gradients = autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates).to(device),
        create_graph=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


# 初始化参数
latent_dim = 100
n_critic = 5            # Critic 与 Generator 的更新次数比例 (C:G)
clip_value = 0.01       # 参数裁剪阈值,控制权重范围 [-c, c](原始 WGAN;使用 GP 时通常不再裁剪)
learning_rate = 0.00005
batches_done = 0        # 绘图计数器初始化

generator = Generator(latent_dim, np.prod(img_shape))
critic = Critic()
if cuda:
    generator.cuda(), critic.cuda()

optimizer_G = optim.RMSprop(generator.parameters(), lr=learning_rate)
optimizer_D = optim.RMSprop(critic.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Critic
        # ---------------------
        optimizer_D.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        # Generate a batch of images
        fake_imgs = generator(z).detach()
        real_validity = critic(real_imgs)
        fake_validity = critic(fake_imgs)
        gradient_penalty = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data)
        # WGAN-GP critic loss: -E[D(real)] + E[D(fake)] + lambda_gp * GP
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) \
            + lambda_gp * gradient_penalty
        d_loss.backward(retain_graph=True)
        optimizer_D.step()

        # Clip weights of critic between (-c, c).
        for p in critic.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # Only update generator every n_critic iterations
        if i % n_critic == 0:
            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            gen_z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim))))
            gen_imgs = generator(gen_z)
            g_loss = -torch.mean(critic(gen_imgs))
            g_loss.backward()
            optimizer_G.step()

        batches_done += 1
```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值