import torch
from torch import nn
from d2l import torch as d2l
# 从零开始
# 拉伸参数gamma、偏移参数beta
# moving_mean, moving_var: 全局均值和方差；eps: 避免除零；momentum: 用来更新全局均值和方差（0.9 或 0.1）
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
# 通过`is_grad_enabled` 来判断当前模式是训练模式还是预测模式
if not torch.is_grad_enabled(