Handwritten Digit Image Generation Based on Generative Adversarial Networks

This post implements handwritten digit generation with a generative adversarial network (GAN). The goal is to generate new handwritten digit images modeled on the MNIST dataset.

The MNIST data is downloaded automatically the first time the script is run (via download=True).

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, sampler
import torchvision
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable

batch_sizes = 64  # mini-batch size
Noise_dim = 96    # dimension of the generator's input noise vector

def show_images(images):
    images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    plt.show()
    return
    
# Preprocess: convert the PIL image to a tensor and normalize from [0, 1] to [-1, 1]
def preprocess_img(x):
    x = torchvision.transforms.ToTensor()(x)
    return (x - 0.5) / 0.5
# De-preprocess: map values from [-1, 1] back to [0, 1] for display
def deprocess_img(x):
    return (x + 1.0) / 2.0
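A quick round-trip check (added here for illustration, not part of the original script) confirms that deprocess_img inverts the normalization applied by preprocess_img:

x_check = torch.rand(1, 28, 28)                   # values in [0, 1], like the output of ToTensor
restored = deprocess_img((x_check - 0.5) / 0.5)   # normalize as in preprocess_img, then undo it
print(torch.allclose(restored, x_check))          # expected: True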

# Define a sampler that yields a fixed contiguous chunk of indices; it inherits from PyTorch's sampler.Sampler base class
class ChunkSampler(sampler.Sampler):
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
    
    
mnist_train_set = torchvision.datasets.MNIST('../data', train=True, download=True, transform=preprocess_img)
mnist_test_set = torchvision.datasets.MNIST('../data', train=False, download=True, transform=preprocess_img)
train_loader = DataLoader(mnist_train_set, batch_size=batch_sizes, sampler=ChunkSampler(55000, 0))   # first 55,000 samples for training
val_data = DataLoader(mnist_train_set, batch_size=batch_sizes, sampler=ChunkSampler(5000, 55000))    # remaining 5,000 samples for validation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optional: visualize the first training batch
# images, _ = next(iter(train_loader))  # fetch the first batch from the loader
# imgs = deprocess_img(images.view(batch_sizes, 784)).numpy()
# show_images(imgs)

# Define the discriminator: a small CNN that maps a 1x28x28 image to a single real/fake logit
class discriminator(nn.Module):
    
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),    # 1x28x28 -> 32x24x24
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),        # 32x24x24 -> 32x12x12
            nn.Conv2d(32, 64, 5, 1),   # 32x12x12 -> 64x8x8
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)         # 64x8x8 -> 64x4x4, i.e. 1024 features after flattening
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
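A quick shape check, added here for illustration (not part of the original post), confirms that the flattened convolutional features have 1024 dimensions, which is why the first fully connected layer takes 1024 inputs:

with torch.no_grad():
    dummy_batch = torch.zeros(4, 1, 28, 28)     # a fake batch of four MNIST-sized images
    print(discriminator()(dummy_batch).shape)   # expected: torch.Size([4, 1]), one logit per image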
    

# Define the generator: maps a noise vector of dimension Noise_dim to a 1x28x28 image
class generator(nn.Module):

    def __init__(self, noise_dim = Noise_dim):
        super(generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7*7*128),
            nn.ReLU(True),
            nn.BatchNorm1d(7*7*128)
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),  # 128x7x7 -> 64x14x14
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),    # 64x14x14 -> 1x28x28
            nn.Tanh()   # squash outputs to [-1, 1] to match the range of the preprocessed real images
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x
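Similarly, a small sanity check (again an addition, not in the original script) verifies that a batch of noise vectors is mapped to MNIST-sized images:

with torch.no_grad():
    demo_noise = (torch.rand(4, Noise_dim) - 0.5) / 0.5   # uniform noise in [-1, 1], as used during training
    print(generator()(demo_noise).shape)                  # expected: torch.Size([4, 1, 28, 28])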
    
# Define the loss functions

bce_loss = nn.BCEWithLogitsLoss()
# Discriminator loss: real images should be classified as 1 and fake images as 0
def discriminator_loss(logits_real, logits_fake):
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float().to(device)
    false_labels = Variable(torch.zeros(size, 1)).float().to(device)
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss
# Generator loss: the generator wants the discriminator to classify its fake images as 1
def generator_loss(logits_fake):
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float().to(device)
    loss = bce_loss(logits_fake, true_labels)
    return loss
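As a quick illustration (the logits below are made-up numbers, not from the original post), a confident and correct discriminator gives a near-zero discriminator loss, while the generator loss on those same fake logits is large because the generator is failing to fool it:

logits_real_demo = torch.full((3, 1), 10.0).to(device)    # D is confident these are real
logits_fake_demo = torch.full((3, 1), -10.0).to(device)   # D is confident these are fake
print(discriminator_loss(logits_real_demo, logits_fake_demo).item())  # close to 0
print(generator_loss(logits_fake_demo).item())                        # roughly 10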

# Train the GAN: alternate between a discriminator update and a generator update on each mini-batch
def Train_GAN(D, G, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                batch_size=batch_sizes, noise_dim=Noise_dim, num_epochs=10):
    iter_count = 0  # number of mini-batches processed so far
    for epoch in range(num_epochs):
        for x, _ in train_loader:  # iterate over the training set, ignoring the labels
            bs = x.shape[0]  # batch size; x has shape (bs, 1, 28, 28)
            real_data = Variable(x).to(device)
            logits_real = D(real_data)  # discriminator scores for real images
            # noise uniformly distributed in [-1, 1]
            sample_noise = (torch.rand(bs, noise_dim) - 0.5) / 0.5
            g_fake_seed = Variable(sample_noise).to(device)
            fake_images = G(g_fake_seed)   # generate fake images
            logits_fake = D(fake_images.detach())   # score the fakes; detach so the D update does not backprop into G
            # update the discriminator
            d_total_error = discriminator_loss(logits_real, logits_fake)   # discriminator loss
            D_optimizer.zero_grad()    # clear gradients
            d_total_error.backward()   # backpropagate
            D_optimizer.step()         # update discriminator weights
            # update the generator
            g_fake_seed = Variable(sample_noise).to(device)  # reuse the same noise batch
            fake_images = G(g_fake_seed)    # regenerate fake images so the graph is fresh for the generator update
            g_logits_fake = D(fake_images)  # score the fakes with the updated discriminator
            g_error = generator_loss(g_logits_fake)   # generator loss
            G_optimizer.zero_grad()   # clear gradients
            g_error.backward()        # backpropagate
            G_optimizer.step()        # update generator weights
            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                print()
            iter_count += 1

# Initialize the networks and the Adam optimizers, then train
D = discriminator().to(device)
G = generator().to(device)
D_optimizer = torch.optim.Adam(D.parameters(), lr=3e-4, betas=(0.5, 0.999))
G_optimizer = torch.optim.Adam(G.parameters(), lr=3e-4, betas=(0.5, 0.999))
Train_GAN(D, G, D_optimizer, G_optimizer, discriminator_loss, generator_loss, num_epochs=10)
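
After training, the generator alone can be used to sample new digits. Below is a minimal sketch of how that might look (an addition to the original script; the filename is arbitrary):

# sample a grid of new handwritten digits from the trained generator
G.eval()   # use running BatchNorm statistics at inference time
with torch.no_grad():
    sample_noise = (torch.rand(16, Noise_dim) - 0.5) / 0.5   # same uniform [-1, 1] noise as in training
    samples = G(sample_noise.to(device))
show_images(deprocess_img(samples.cpu().numpy()))

# optionally save the generator weights for later reuse
torch.save(G.state_dict(), 'mnist_gan_generator.pth')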
