torch.nn.GroupNorm
字面意思是分组做Normalization,官方说明在这里。
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)
计算公式
E[x]是x的均值;
Var[x]是标准差;
gama和beta是训练参数,如果不想使用,可以通过参数affine=False设置。默认为True;
eposilon是输入参数,防止Var为0,默认值为1e-05,可以通过参数eps修改。
输入张量要求
输入的张量至少是2维的,其中第一维度为Channel,后面的维度为特征数据。
使用示例
GroupNorm 是将第一维度的Channels按group分,然后每个group按照上面的计算公式做计算。
比如,
input shape = (4,5)
gn = GroupNorm (2,4)
output = gn(input)
那么output就是将4个channel的数据分为2组,前1-2channel为一组,并按公式计算;后3-4channel为一组,并按公式计算;但是这里输出的shape还是(4,5)
GroupNorm 不会改变输入张量的shape,它只是按照group做normalization
三维,四维以上都一样,比如这里的input shape =(4,1,2,3,4,5),GroupNorm 的作用仅仅针对第一维度的channel。
报错
如果GroupNorm 输入的channel num与输入不一致,则会报错
RuntimeError: Expected number of channels in input to be divisible by num_groups