DCGAN实现

本文围绕DCGAN模型展开,介绍其原理、数据集和训练框架。详细阐述模型代码分布及作用,包括数据处理、网络结构定义和模型训练。通过测试调参分析batchsize、epoch num和学习率对模型的影响,指出参数选择对训练稳定性和效果至关重要,还提及使用GPU加速训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.项目简介

DCGAN,全称是 Deep Convolution Generative Adversarial Networks(深度卷积生成对抗网络),是 Alec Radfor 等人于2015年提出的一种模型。该模型在 Original GAN 的理论基础上,开创性地 将 CNN 和 GAN 相结合 以 实现对图像的处理,并 提出了一系列 对网络结构的限制 以 提高网络的稳定性。

DCGAN 的网络结构 在之后的各种改进 GAN 中得到了广泛的沿用,可以说是当今各类改进 GAN 的前身。

原理:

  1. 全卷积网络(all convolutional net:用步幅卷积(strided convolutions)替代确定性空间池化函数(deterministic spatial pooling functions)(比如最大池化),让网络自己学习downsampling方式。作者对 generator 和 discriminator 都采用了这种方法。
  2. 取消全连接层: 比如,使用 全局平均池化(global average pooling)替代 fully connected layer。global average pooling会降低收敛速度,但是可以提高模型的稳定性。GAN的输入采用均匀分布初始化,可能会使用全连接层(矩阵相乘),然后得到的结果可以reshape成一个4 dimension的tensor,然后后面堆叠卷积层即可;对于鉴别器,最后的卷积层可以先flatten,然后送入一个sigmoid分类器。
  3. 批归一化(Batch Normalization: 即将每一层的输入变换到0均值和单位方差(注:其实还需要shift 和 scale),BN 被证明是深度学习中非常重要的 加速收敛 和 减缓过拟合 的手段。这样有助于解决 poor initialization 问题并帮助梯度流向更深的网络。防止G把所有rand input都折叠到一个点。但是实践表明,将所有层都进行Batch Normalization,会导致样本震荡和模型不稳定,因此 只对生成器(G)的输出层和鉴别器(D)的输入层使用BN。
  4. Leaky Relu 激活函数: 生成器(G),输出层使用tanh 激活函数,其余层使用relu 激活函数。鉴别器(D),都采用leaky rectified activation。

数据集:

Cifar10   MNIST

训练框架:

torch(2.0.1+cu117)

2.模型代码分布及作用介绍

2.1依赖库导入

from __future__ import print_function

import argparse

import os

import random

import torch

import torch.nn as nn

import torch.nn.parallel

import torch.backends.cudnn as cudnn

import torch.optim as optim

import torch.utils.data

import torchvision.datasets as dset

import torchvision.transforms as transforms

import torchvision.utils as vutils

2.2定义参数

parser = argparse.ArgumentParser()

parser.add_argument('--dataset', required=True, help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake')

parser.add_argument('--dataroot', required=False, help='path to dataset')

parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)

parser.add_argument('--batchSize', type=int, default=64, help='input batch size')

parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')

parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')

parser.add_argument('--ngf', type=int, default=64)

parser.add_argument('--ndf', type=int, default=64)

parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')

parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')

parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')

parser.add_argument('--cuda', action='store_true', default=False, help='enables cuda')

parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')

parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')

parser.add_argument('--netG', default='', help="path to netG (to continue training)")

parser.add_argument('--netD', default='', help="path to netD (to continue training)")

parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')

parser.add_argument('--manualSeed', type=int, help='manual seed')

parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')

parser.add_argument('--mps', action='store_true', default=False, help='enables macOS GPU training')

2.3数据下载及处理

if opt.dataset in ['imagenet', 'folder', 'lfw']:

    # folder dataset

    dataset = dset.ImageFolder(root=opt.dataroot,

                               transform=transforms.Compose([

                                   transforms.Resize(opt.imageSize),

                                   transforms.CenterCrop(opt.imageSize),

                                   transforms.ToTensor(),

                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

                               ]))

    nc = 3

elif opt.dataset == 'rsum':

    classes = [c + '_train' for c in opt.classes.split(',')]

    dataset = dset.LSUN(root=opt.dataroot, classes=classes,

                        transform=transforms.Compose([

                            transforms.Resize(opt.imageSize),

                            transforms.CenterCrop(opt.imageSize),

                            transforms.ToTensor(),

                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

                        ]))

    nc = 3

elif opt.dataset == 'cifar10':

    dataset = dset.CIFAR10(root=opt.dataroot, download=True,

                           transform=transforms.Compose([

                               transforms.Resize(opt.imageSize),

                               transforms.ToTensor(),

                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

                           ]))

    nc = 3

elif opt.dataset == 'mnist':

    dataset = dset.MNIST(root=opt.dataroot, download=True,

                         transform=transforms.Compose([

                             transforms.Resize(opt.imageSize),

                             transforms.ToTensor(),

                             transforms.Normalize((0.5,), (0.5,)),

                         ]))

    nc = 1

elif opt.dataset == 'fake':

    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),

                            transform=transforms.ToTensor())

    nc = 3

assert dataset

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,

                                         shuffle=True, num_workers=int(opt.workers))

use_mps = opt.mps and torch.backends.mps.is_available()

作用:

如果 opt.dataset 的取值为 'imagenet'、'folder' 或 'lfw',那么它将被视为一个文件夹数据集。代码将使用 dset.ImageFolder 加载该数据集,并进行一系列的图像预处理和标准化操作。

如果 opt.dataset 的取值为 'rsum',那么它将被视为 LSUN 数据集。代码将使用 dset.LSUN 加载该数据集,并进行相应的预处理和标准化操作。

如果 opt.dataset 的取值为 'cifar10',那么它将被视为 CIFAR-10 数据集。代码将使用 dset.CIFAR10 加载该数据集,并进行相应的预处理和标准化操作。

如果 opt.dataset 的取值为 'mnist',那么它将被视为 MNIST 数据集。代码将使用 dset.MNIST 加载该数据集,并进行相应的预处理和标准化操作。

如果 opt.dataset 的取值为 'fake',那么它将被视为一个虚拟数据集。代码将使用 dset.FakeData 创建一个具有指定大小的虚拟数据集。

最后,代码将根据所选择的数据集类型确定输入图像的通道数 nc。

在加载数据集之后,代码使用 torch.utils.data.DataLoader 创建一个数据加载器 dataloader,用于在训练过程中以指定的批量大小加载数据。

如果 opt.mps 为 True 并且支持 Tensor Core 加速,则使用 torch.backends.mps.is_available() 来检查是否可用,并将 use_mps 设置为相应的值。

3.4加速设备

if opt.cuda:

    device = torch.device("cuda:0")

elif use_mps:

    device = torch.device("mps")

else:

    device = torch.device("cpu")

本次实验采用cuda作为加速

3.5定义网络结构

class Generator(nn.Module):

    def __init__(self, ngpu):

        super(Generator, self).__init__()

        self.ngpu = ngpu

        self.main = nn.Sequential(

            # input is Z, going into a convolution

            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),

            nn.BatchNorm2d(ngf * 8),

            nn.ReLU(True),

            # state size. (ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ngf * 4),

            nn.ReLU(True),

            # state size. (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ngf * 2),

            nn.ReLU(True),

            # state size. (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ngf),

            nn.ReLU(True),

            # state size. (ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),

            nn.Tanh()

            # state size. (nc) x 64 x 64

        )

    def forward(self, input):

        if input.is_cuda and self.ngpu > 1:

            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))

        else:

            output = self.main(input)

        return output

netG = Generator(ngpu).to(device)

netG.apply(weights_init)

if opt.netG != '':

    netG.load_state_dict(torch.load(opt.netG))

print(netG)

class Discriminator(nn.Module):

    def __init__(self, ngpu):

        super(Discriminator, self).__init__()

        self.ngpu = ngpu

        self.main = nn.Sequential(

            # input is (nc) x 64 x 64

            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),

            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ndf * 2),

            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ndf * 4),

            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),

            nn.BatchNorm2d(ndf * 8),

            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),

            nn.Sigmoid()

        )

    def forward(self, input):

        if input.is_cuda and self.ngpu > 1:

            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))

        else:

            output = self.main(input)

        return output.view(-1, 1).squeeze(1)

作用:

定义了一个 GAN 模型的生成器和判别器。生成器使用反卷积层将输入的随机噪声向量 Z 转换成一张数字图像,而判别器则是一个标准的卷积神经网络,用于判断输入的图像是否为真实的数字图像。

在生成器中,首先将随机噪声向量 Z 通过一个反卷积层转换成一个大小为 (ngf*8) x 4 x 4 的特征图。接着,通过一系列反卷积层逐步增加特征图的大小,最终得到一个大小为 (nc) x 64 x 64 的图像。其中,ngf 是生成器中特征图的通道数,nc 是图像的通道数(对于 MNIST 数据集来说,nc=1)。

在判别器中,首先将输入的图像通过一个卷积层变成一个大小为 (ndf) x 32 x 32 的特征图。接着,通过一系列卷积层逐步减小特征图的大小,最终得到一个大小为 1 的标量,表示输入的图像是否为真实的数字图像。其中,ndf 是判别器中特征图的通道数。

需要注意的是,在训练 GAN 模型时,生成器和判别器是分别训练的。具体来说,每一次迭代中,首先生成一批随机噪声向量作为生成器的输入,然后使用生成器生成一批假的数字图像。接着,将这些真实的和假的图像放入判别器中,并计算它们的损失函数。最后,根据损失函数的值更新生成器和判别器的参数。这个过程会反复进行多轮,直到生成器能够生成与真实图像相似的假图像。

3.6模型训练

if __name__ == '__main__':

    for epoch in range(opt.niter):

        for i, data in enumerate(dataloader, 0):

         ############################

        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))

        ###########################

        # train with real

            netD.zero_grad()

            real_cpu = data[0].to(device)

            batch_size = real_cpu.size(0)

            label = torch.full((batch_size,), real_label,

                            dtype=real_cpu.dtype, device=device)

            output = netD(real_cpu)

            errD_real = criterion(output, label)

            errD_real.backward()

            D_x = output.mean().item()

        # train with fake

            noise = torch.randn(batch_size, nz, 1, 1, device=device)

            fake = netG(noise)

            label.fill_(fake_label)

            output = netD(fake.detach())

            errD_fake = criterion(output, label)

            errD_fake.backward()

            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake

            optimizerD.step()

        ############################

        # (2) Update G network: maximize log(D(G(z)))

        ###########################

            netG.zero_grad()

            label.fill_(real_label)  # fake labels are real for generator cost

            output = netD(fake)

            errG = criterion(output, label)

            errG.backward()

            D_G_z2 = output.mean().item()

            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'

                % (epoch, opt.niter, i, len(dataloader),

                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if i % 100 == 0:

                vutils.save_image(real_cpu,

                              '%s/real_samples.png' % opt.outf,

                              normalize=True)

                fake = netG(fixed_noise)

                vutils.save_image(fake.detach(),

                              '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),

                              normalize=True)

            if opt.dry_run:

                break

    # do checkpointing

        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))

        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

作用:

通过循环迭代的方式训练生成器(netG)和判别器(netD)。

在每个迭代中,首先遍历数据加载器(dataloader)以获取真实的图像数据。然后分别对判别器和生成器进行更新:

  1. 更新判别器(D network):

使用真实图像计算判别器的损失(errD_real),并进行反向传播(backward)和参数更新(optimizerD.step())。

使用生成器生成的假图像计算判别器的损失(errD_fake),同样进行反向传播和参数更新。

最后计算总体的判别器损失()。

  1. 更新生成器(G network):

使用生成器生成的假图像,计算判别器的输出,并据此计算生成器的损失(errG),然后进行反向传播和参数更新。

在训练过程中,会打印出当前迭代的损失值以及部分生成的图像,并且可以选择在每一定数量的迭代中保存生成的图像。

最后,根据需要进行模型参数的保存(checkpointing)。

3. 测试调参和结果分析

3.1程序运行启动流程

3.1.1训练模型(例):

python main.py --dataset mnist --dataroot C:\Users\MSI-NB\Desktop\code --batchSize 64 --niter 1 –lr 0.0002  --cuda

运行该命令会先下载数据集(未下载),按要求存放在指定路径,然后开始训练模型,输出误差及其模型评价指标。

3.1.2结果(例):

          

                           result 1样例生成                           test 1样图例子

    

Loss_D: 0.1867 Loss_G: 2.8261 D(x): 0.9044 D(G(z)): 0.0758 / 0.0821

Loss_D: 判别器的损失,用于衡量判别器对真实图像和生成图像的识别能力。

Loss_G: 生成器的损失,表示生成器在欺骗判别器方面的表现。

D(x): 表示判别器对真实图像的平均输出。在这里,值接近 1 表明判别器在真实图像上的表现良好。

D(G(z)): 表示判别器对生成图像的平均输出。值接近 0 表明判别器被生成器成功“欺骗”。

3.1.3对抗模型

每轮模型训练结束会生成一个判别器(netD)模型(用于判别生成图片和真实图片)和生成器(netG)模型(更加真实的趋向于生成一个能够使判别器模型能力下降的一个对抗模型)

3.2对比调参分析

(实验评价指标采取生成图像,和各类误差及其相似指标进行对比分析)

(不是测试epochs时,只训练一轮,只有一个对比生成图像)

3.2.1 测试batchsize

1不同batchsize下的各个指标

batchsize

Loss_D

Loss_G

D(x)

D(G(z))

32

0.3514

3.1977

0.7631

0.8112

64

0.1867

2.8261

0.9044

0.9232

128

0.1247

3.1784

0.9147

0.4762

*mnist数据集结果

2不同batchsize下的结果

batchsize

测试样例(原图例和生成图)

32

64

128

result 2不同batchsize下的结果

分析:

随着 batchsize 的增加,Loss_D 和 D(x) 都在下降,而 Loss_G 和 D(G(z)) 则在上升。

当 batch size 从 32 增加到 128 时,判别器的损失(Loss_D)逐渐减小,生成器的损失(Loss_G)逐渐增加。

判别器对真实数据的判断准确度(D(x))逐渐增加,而对生成数据的判断准确度(D(G(z)))则逐渐降低。

不同的batchsize对对抗模型有极大的影响,无论是生成器还是判别器都有性能因素在其中。

较大的 batch size 可能有助于提高训练的稳定性和模型的泛化能力,因为它使得每次更新所使用的样本数量更多。

但同时也可能导致生成器训练变得更加困难,因为生成器需要更长的时间来适应更大的 batch size,并且可能会导致模式崩溃或其他训练不稳定的问题。

3.2.2测试epoch num

result 3不同epoch下的结果1mnist

因为程序跑出来的原图不好分辨,裁剪图如下:

result 4不同epoch下的结果2mnist

result 5不同epoch下的结果1Cifar10

同理:

result 6不同epoch下的结果2Cifar10

3.2.3测试学习率

3不同lr下的各个指标

Learning_rate

Loss_D

Loss_G

D(x)

D(G(z))

0.0001

0.0012

8.0418

0.9999

2.75

0.00015

0.2597

4.0671

0.9447

5.72

0.0002

0.1247

3.1784

0.9147

0.4762

0.00025

0.5013

1.9783

0.6533

0.1776

*mnist数据集结果

4不同lr下的结果

Learning_rate

测试样例(原图例和生成图)

0.0001

0.00015

0.0002

见上3.2.1中batchsize=128

0.00025

result 7不同学习率下的结果

分析:

随着学习率从 0.0001 增加到 0.00025,生成器的损失(Loss_G)逐渐减小,而鉴别器的损失(Loss_D)则呈现不规律的波动。

随着学习率的增加,鉴别器对真实数据的判断准确度(D(x))似乎出现了一定程度的下降,而对生成数据的判断准确度(D(G(z)))则呈现不规则的变化。

基于以上观察,可以得出以下结论:

学习率的选择对于 GAN 模型的训练至关重要。过大或者过小的学习率都可能导致训练不稳定或者收敛缓慢。

调参中,学习率为 0.0002 时显示出相对较好的结果,但仍然需要进一步实验来寻找最佳的学习率。

因此,针对不同的数据集和模型架构,需要通过实验来调整学习率,以找到最适合的值,从而获得更好的训练效果。

5.问题和解决方法

6.总结和思考

Batch size 的选择对于训练的稳定性和生成器的效果至关重要。较大的 batch size 可能有助于提高模型的泛化能力,但同时也可能导致生成器训练困难和训练不稳定。

学习率的选择对于 GAN 模型的训练同样非常重要。过大或过小的学习率都可能导致训练不稳定或收敛缓慢。

监控损失和评估指标的变化是评估模型训练进展的重要手段。通过观察这些指标的变化,可以了解模型在训练过程中的表现,并作出相应的调整。

实验使用了GPU对训练模型进行加速,提高了实验效率。进行对比调参,了解到每个参数对模型

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

恶心猫charming

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值