def get_z0(self, batch, train=True):
    """Sample the initial noise tensor ``z0`` with the same shape as ``batch``.

    Args:
        batch: Object whose ``.shape`` unpacks into exactly four sizes
            ``(n, c, h, w)``. Only the shape is read; the values are ignored.
        train: Accepted for interface compatibility; not used here.

    Returns:
        A tensor of shape ``(n, c, h, w)`` drawn from ``torch.randn`` and
        multiplied by ``self.noise_scale``. No device or dtype is passed to
        ``torch.randn``, so the result uses torch's defaults rather than
        those of ``batch``; callers must move or cast it themselves.

    Raises:
        ValueError: If ``batch.shape`` does not have exactly four entries.
        NotImplementedError: If ``self.init_type`` is not ``'gaussian'``.
    """
    # The unpacking doubles as a rank check: anything other than a 4-D
    # batch fails here, before the init type is even inspected.
    n, c, h, w = batch.shape

    if self.init_type != 'gaussian':
        raise NotImplementedError("INITIALIZATION TYPE NOT IMPLEMENTED")

    # Zero-mean, unit-variance Gaussian noise, rescaled by noise_scale.
    return torch.randn((n, c, h, w)) * self.noise_scale
代码片段,生成高斯噪声
最新推荐文章于 2025-03-07 19:58:47 发布