【Week-G7】Semi-Supervised GAN 实践,使用MNIST数据集

本次学习进行Semi-Supervised GAN的实践,数据集为MNIST
主要为了解惑:加入生成的图像本身就携带标签,比如数字1~9,那么:为什么还需要鉴别器判断输入图像的真假,而不直接判断图像属于0-9中的哪一个数字?

一、基础知识

本次学习使用到的SGAN将GAN扩展到半监督学习方式,通过强制判别器D来输出类别标签。具体结构如下图:
在这里插入图片描述

输入数据集:N类中某一个
生成器G:输出第N+1个类
判别器D:充当分类器C的效果
训练时:判别器D被用于预测输入时属于N+1类中的哪一个

SGAN可以用于训练效果更好的判别器D,并且比普通的GAN产生更加高质量的样本。
在这里插入图片描述

二、代码实现

2.1 导入所需模块 & 设置网络初始参数

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值