最近,一篇非常火的报道,使用pytorch 加 50 行核心代码模拟 GAN 对抗神经网络,自己尝试走了一遍,并对源码提出自己的理解。原文链接如下 https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.nr3akui9z
code 地址 https://github.com/devnag/pytorch-generative-adversarial-networks
首先配置pytorch, 这个直接去github 上的pytorch官网按照教程去配,链接地址 https://github.com/pytorch/pytorch
之后把code 下下来,直接运行就可以了。
数据的产生,正样本也就是使用一个均值为4,标准差为1.25的高斯分布,
torch.Tensor(np.random.normal(mu, sigma, (1, n)))负样本是随机数
torch.rand(m, n)
两个网路的定义,一个是生成网络,将其参数打印出来如下
(map1): Linear (1 -> 50)
(map2): Linear (50 -> 50)
(map3): Linear (50 -> 1)
)
输入是(1L,1L)
另一个是判别网路,输入是(200L,1L)
Discriminator (
(map1): Linear (200 -> 50)
(map2):

本文介绍了如何使用50行PyTorch代码实现GAN,包括配置PyTorch环境、数据生成、生成网络与判别网络的定义。在初始化网络时,对数据进行预处理,通过对比不同返回方式的效果,说明了数据增强对于提升生成器性能的重要性。实验结果显示,返回data+diffs的方式能生成更接近真实数据的样本,因为其更高的维度提供了更强的学习能力。
最低0.47元/天 解锁文章
3571

被折叠的 条评论
为什么被折叠?



