看到了RL代码中经常用到的masked_whiten,最近产生了一个问题,为什么shift_mean=False的情况下,要whitened += mean,而不是whitened += mean*torch.rsqrt(var + 1e-8)给他还原回去?
def masked_whiten(values, mask, shift_mean=True):
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
分析了一下均值和方差,whitened = (values - mean) * torch.rsqrt(var + 1e-8)后得到的均值为0,方差为1
- 如果whitened += mean,均值变成了mean,方差仍然为1,因为加上常数,方差不会发生变化
- 如果whitened += mean*torch.rsqrt(var + 1e-8),均值不再是mean,方差也不再是1

被折叠的 条评论
为什么被折叠?



