GAN网络学习路径
1、经典GAN网络的公式
2、网络工作原理
(1):从噪声z进行随机抽样,传入G网络,生成新数据G(z)和其概率分布Pg(G(z))
(2):将真实数据和G生成的新数据一起传入D网络进行真假判别,通过sigmoid函数来输出判定类别
(3):迭代优化D和G损失函数,根据D来调整G
(4):直到D和G达到收敛,即D无法判断G产生数据的真假性,即Pg(G(z))已经非常逼近Pdata(x)
4、基于pytorch的代码实现
可以看代码讲解很详细:小白读基于pytorch的GAN网络代码 - 知乎
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="ad