基础GAN的原理还不懂的,先看:生成式对抗神经网络(GAN)原理给你讲的明明白白
一、加载数据
# 数据归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
])
# 加载内置数据
train_ds = torchvision.datasets.MNIST('data', # 当前目录下的data文件夹
train=True, # train数据
transform=transform,
download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
- 没有下载过MNIST数据集的直接用上面的代码,会自动下载,下载速度看人品~~~~~
- 有MNIST数据集的,把数据文件放在程序目录下,代码中的‘data’更改为你的文件名,download改为False
- 自行数据下载地址:https://download.youkuaiyun.com/download/m0_62128864/85045154?spm=1001.2014.3001.5501
二、创建生成器
# 定义生成器
# 输入是长度为100的噪声(正态分布随机数)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.gen = nn.Sequential(nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 28 * 28),
nn.Tanh()
)
# 定义前向传播 x表示长度为