批量规范化(batch normalization)
可持续加速深层网络的的收敛速度。再结合残差块,批量规范化使得研究人员可以训练100层以上的网络。
为什么需要批量规范化层呢?
1、数据预处理的方式通常会对最终结果产生巨大影响;
2、中间层的变量可能具有更广变化范围,由于可变值的范围不同,是否需要对学习率进行调整;
3、深层的网络很复杂,容易过拟合。
批量规范化层的在训练模式中,通过小批量统计数据进行规范化;在预测模式中通过数据集统计进行规范化。
从头开始实现一个具有张量的批量规范化层
import torch
from torch import nn
from d2l import torch as d2l
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
if not torch.is_grad_enabled():
# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# 使用全连接层的情况,计算特征维上的均值和方差
mean = X.mean(dim=0)
var = ((X - mean) ** 2).mean(dim=0)
else:
# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
# 这里我们需要保持X的形状以便后面可以做广播运算
mean = X.mean(dim=(0, 2, 3), keepdim=True)
var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=