# batchnorm
1. Code
import torch
from torch import nn
from d2l import torch as d2l
# moving_mean, moving_var approximate the mean and variance over the whole dataset;
# eps is epsilon; momentum (typically 0.9) is used to update moving_mean and moving_var
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
# Use is_grad_enabled to tell whether we are in training mode (enabled) or prediction mode (not is_grad_enabled)
if not torch.is_grad_enabled():
# In prediction mode, normalize directly with the supplied moving_mean and moving_var
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
# len(X.shape) == 2 means a fully connected layer; == 4 means a 2D convolutional layer
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# Fully connected case: compute the per-feature mean and variance along the columns
mean = X.mean(dim=0)
var = ((X - mean) ** 2).mean(dim=0)
else:
# 2D convolutional case: compute mean and variance per channel (axis=1)
# Keep X's shape (keepdim) so broadcasting works later