4.BN推导

本文深入解析Batch Normalization(BN)层的工作原理,包括其在神经网络中的作用、公式表达、参数推导及实际应用注意事项。BN层能加速训练过程,提高模型泛化能力,但也有其局限性,如对小batch大小的依赖。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  参考博客:https://www.cnblogs.com/guoyaohua/p/8724433.html

  参考知乎:https://www.zhihu.com/question/38102762/answer/85238569

1.BN的原理

  我们知道,神经网络在训练的时候,如果对图像做白化(即通过变换将数据变成均值为0,方差为1)的话,训练效果就会好。那么BN其实就是做了一个推广,它对隐层的输出也做了归一化的操作。那么为什么归一化操作能够使得训练效果好那么多呢?机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。假设我们用批量梯度下降法来训练,在没有BN层之前,你对输入做了白化,对于每一个batch在喂到神经网络之前分布是一致的,但是通过每一个隐层的激活函数后,每一个batch的隐层输出分布可能就有很大差异了,而且随着层数的加深,这种差异也会越来越大。

2.BN的公式表达

 

  下面具体解释一下上述公式的意思:前三步的作用就是把输入数据分布归一化成一个正态分布,这样一来,每一个隐藏层未经激活的输出值经过激活函数之后大部分都落在了激活函数的线性区域(参考sigmoid函数),但是由于神经网络强大的表达能力就是基于它的高度非线性化,如果这种特性失去了,网络在再深也没有用了。原作者采取了一个办法,就是第四步,这个步骤其实就是将正太分布进行一个偏移(scale and shift),从某种意义上来说,就是在线性和非线性之间做一个权衡,而这个偏移的参数gamma和bata是神经网络在训练时学出来的。另外提一下,公式3中的eplison应该是一个很小的正数,为了防止分母为0,tensorflow的源码是这样解释的。

3.BN的参数推导

  现在我们看下,如何通过链式法则更新BN中的参数(其中就相当于之前的传播误差):

 

4.BN在实际运用中需要注意的问题 (此处参考博客:https://blog.youkuaiyun.com/xys430381_1/article/details/85141702 

  当我们的测试样本前向传导的时候,上面的均值u、标准差σ要哪里来?其实网络一旦训练完毕,参数都是固定的,这个时候即使是每批训练样本进入网络,那么BN层计算的均值u、和标准差σ都是固定不变的。我们可以采用这些训练阶段的均值u、标准差σ来作为测试样本所需要的均值、标准差,于是最后测试阶段的u和σ计算公式如下:

上面简单理解就是:对于均值来说直接计算所有batch u值的平均值;然后对于标准偏差采用每个batch σB的无偏估计。最后测试阶段,BN的使用公式就是:

也就是说, 在test的时候,BN用的是固定的mean和var, 而这个固定的mean和var是通过训练过程中对mean和var进行滑动平均得到的,被称之为moving_mean和moving_var。在实际操作中,每次训练时应当更新一下moving_mean和moving_var,然后把BN层的这些参数保存下来,留作测试和预测时使用。(如果不太了解滑动平均可以参考博文:https://blog.youkuaiyun.com/qq_18888869/article/details/83009504

5.BN的好处(此处参考知乎:https://www.zhihu.com/question/38102762龙鹏-言有三的回答)

  (1) 减轻了对参数初始化的依赖,这是利于调参的朋友们的。

  (2) 训练更快,可以使用更高的学习率。

  (3) BN一定程度上增加了泛化能力,dropout等技术可以去掉。

6.BN的缺陷

  batch normalization依赖于batch的大小,当batch值很小时,计算的均值和方差不稳定。研究表明对于ResNet类模型在ImageNet数据集上,batch从16降低到8时开始有非常明显的性能下降,在训练过程中计算的均值和方差不准确,而在测试的时候使用的就是训练过程中保持下来的均值和方差

  由于这一个特性,导致batch normalization不适合以下的几种场景。

  (1)batch非常小,比如训练资源有限无法应用较大的batch,也比如在线学习等使用单例进行模型参数更新的场景。

  (2)rnn,因为它是一个动态的网络结构,同一个batch中训练实例有长有短,导致每一个时间步长必须维持各自的统计量,这使得BN并不能正确的使用。在rnn中,对bn进行改进也非常的困难。不过,困难并不意味着没人做,事实上现在仍然可以使用的,不过这超出了咱们初识境的学习范围。

转载于:https://www.cnblogs.com/tangweijqxx/p/10678935.html

class AxialAttention(nn.Module): def __init__(self, in_planes, out_planes, groups=8, kernel_size=56, stride=1, bias=False, width=False): assert (in_planes % groups == 0) and (out_planes % groups == 0) super(AxialAttention, self).__init__() self.in_planes = in_planes self.out_planes = out_planes self.groups = groups self.group_planes = out_planes // groups self.kernel_size = kernel_size self.stride = stride self.bias = bias self.width = width # Multi-head self attention self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1, padding=0, bias=False) self.bn_qkv = nn.BatchNorm1d(out_planes * 2) self.bn_similarity = nn.BatchNorm2d(groups * 3) #self.bn_qk = nn.BatchNorm2d(groups) #self.bn_qr = nn.BatchNorm2d(groups) #self.bn_kr = nn.BatchNorm2d(groups) self.bn_output = nn.BatchNorm1d(out_planes * 2) # Position embedding self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True) query_index = torch.arange(kernel_size).unsqueeze(0) key_index = torch.arange(kernel_size).unsqueeze(1) relative_index = key_index - query_index + kernel_size - 1 self.register_buffer('flatten_index', relative_index.view(-1)) if stride > 1: self.pooling = nn.AvgPool2d(stride, stride=stride) self.reset_parameters() def forward(self, x): if self.width: x = x.permute(0, 2, 1, 3) else: x = x.permute(0, 3, 1, 2) # N, W, C, H N, W, C, H = x.shape x = x.contiguous().view(N * W, C, H) # Transformations qkv = self.bn_qkv(self.qkv_transform(x)) q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) # Calculate position embedding all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size) q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) qr = torch.einsum('bgci,cij->bgij', q, q_embedding) kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) qk = torch.einsum('bgci, bgcj->bgij', q, k) stacked_similarity = torch.cat([qk, qr, kr], dim=1) stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1) #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk) # (N, groups, H, H, W) similarity = F.softmax(stacked_similarity, dim=3) sv = torch.einsum('bgij,bgcj->bgci', similarity, v) sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H) output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2) if self.width: output = output.permute(0, 2, 1, 3) else: output = output.permute(0, 2, 3, 1) if self.stride > 1: output = self.pooling(output) return output def reset_parameters(self): self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes)) #nn.init.uniform_(self.relative, -0.1, 0.1) nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes)) #AxialBlock:中间通道plane×2=输出通道。inc=outc时,plane=0.5×inc。 class AxialBlock(nn.Module): expansion = 2 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, kernel_size=56): super(AxialBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv_down = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size) self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True) self.conv_up = conv1x1(width, planes * self.expansion) self.bn2 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv_down(x) out = self.bn1(out) out = self.relu(out) out = self.hight_block(out) out = self.width_block(out) out = self.relu(out) out = self.conv_up(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out 如何动态计算YOLOv8中AxialAttention的合规kernel_size?有没有项目内自带的指令,类似于model.load()之类的兼容项目的指令?
最新发布
07-06
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值