使用 PyTorch 构建首个 GAN
1. 定义损失函数和优化器
首先,我们要为判别器网络定义损失函数,并为两个网络定义优化器:
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
这里, nn.BCELoss() 代表二元交叉熵损失函数。
2. 加载 MNIST 数据集到 GPU 内存
dataset = dset.MNIST(root=DATA_PATH, download=True,
transform=transforms.Compose([
transforms.Resize(X_DIM),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
]))
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH
超级会员免费看
订阅专栏 解锁全文
2138

被折叠的 条评论
为什么被折叠?



