PyTorch Tutorial -- DCGAN

This post walks through how DCGAN (Deep Convolutional Generative Adversarial Network) works, including the roles of the generator and the discriminator. During training, the generator tries to produce realistic fake images while the discriminator tries to tell real training images from generated fakes. The end goal is a generator whose fakes match the training-data distribution so well that the discriminator can only guess real vs. fake at 50% confidence. The post also covers network definition, the loss function, the training strategy, and visualization tricks.

  • About GAN

    1. GANs are a framework for teaching a DL model to capture the training data’s distribution so we can generate new data from that same distribution.
    2. The job of the generator is to spawn ‘fake’ images that look like the training images.
    3. The job of the discriminator is to look at an image and output whether or not it is a real training image or a fake image from the generator.
    4. The equilibrium of this game is when the generator is generating perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence that the generator output is real or fake.
    5. D(x) is the discriminator network which outputs the (scalar) probability that x came from training data rather than the generator. D(x) can also be thought of as a traditional binary classifier.
    6. G(z) represents the generator function which maps the latent vector z to data-space. The goal of G is to estimate the distribution that the training data comes from (p_data) so it can generate fake samples from that estimated distribution (p_g).
    7. D and G play a minimax game in which D tries to maximize the probability that it correctly classifies reals and fakes (log D(x)), and G tries to minimize the probability that D will predict its outputs are fake (log(1−D(G(z)))).
    8. In theory, the solution to this minimax game (written out below) is where p_g = p_data, and the discriminator guesses randomly whether the inputs are real or fake. However, the convergence theory of GANs is still being actively researched and in reality models do not always train to this point.
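
    Written out, the minimax objective from Goodfellow’s paper (quoted in the tutorial) is:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```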
  • define network – generator and discriminator

    1. import torch.nn as nn
      notice: layer functions and the kernel size / stride / padding arithmetic
      best to do: comment each stage as input size -> state size -> … -> state size -> output size
      notice: in-place operations, nn.ReLU(True), nn.Tanh(), nn.Sigmoid(), nn.LeakyReLU(0.2, inplace=True)
    2. instantiate and move to device, weight initialization
      net.apply(weights_init) initializes the weights of specific layer types
      weights_init takes each submodule as its argument (apply() walks the whole model)
    3. instantiate the optimizers, one for G and one for D (all three steps are sketched below)
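
A condensed sketch of all three steps, following the tutorial’s 64×64 DCGAN architecture and hyperparameters (nz/ngf/ndf/nc are the tutorial’s names; nn.Sequential is used here instead of the tutorial’s nn.Module subclasses for brevity):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# tutorial defaults: latent size, G/D feature map sizes, image channels
nz, ngf, ndf, nc = 100, 64, 64, 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. define networks, commenting each stage with its tensor size
netG = nn.Sequential(
    # input: latent vector Z, shape (nz, 1, 1)
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True),
    # state size: (ngf*8) x 4 x 4
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
    # state size: (ngf*4) x 8 x 8
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
    # state size: (ngf*2) x 16 x 16
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True),
    # state size: (ngf) x 32 x 32
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh(),
    # output size: (nc) x 64 x 64
).to(device)

netD = nn.Sequential(
    # input: (nc) x 64 x 64
    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),
    # state size: (ndf) x 32 x 32
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
    # state size: (ndf*2) x 16 x 16
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
    # state size: (ndf*4) x 8 x 8
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True),
    # state size: (ndf*8) x 4 x 4
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid(),
    # output: scalar probability
).to(device)

# 2. weight initialization: apply() calls weights_init on every submodule m
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

# 3. one optimizer per network (Adam with beta1=0.5, per the tutorial)
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
```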
  • loss function

    1. define: criterion = nn.BCELoss()
    2. call: err = criterion(output, label), as sketched below
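
A minimal sketch of the define/call pattern; output and label here are hypothetical stand-ins for D’s flattened predictions and the 0/1 target vector:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# hypothetical stand-ins: D's predictions in (0,1) and the 0/1 targets
output = torch.rand(128, requires_grad=True)  # in practice: netD(batch).view(-1)
label = torch.full((128,), 1.0)               # tutorial convention: real = 1, fake = 0

# BCELoss computes -mean(y*log(x) + (1-y)*log(1-x))
err = criterion(output, label)
err.backward()
```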
  • data processing
    create dataset -> dataloader -> read -> visualize

    1. import torchvision.datasets as dset; dset.ImageFolder
    2. import torch.utils.data; torch.utils.data.DataLoader (note: DataLoader lives in torch.utils.data, not torchvision)
    3. next(iter(dataloader))
    4. import matplotlib.pyplot as plt; plt.figure, axis, title, imshow (see the sketch below)
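
Putting the four steps together, roughly as the tutorial does (the data/celeba path is an assumption; any directory laid out for ImageFolder works):

```python
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

# assumed path: any ImageFolder-style layout (root/class_x/xxx.png) works
dataset = dset.ImageFolder(
    root="data/celeba",
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # match G's Tanh output range
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

# read one batch and visualize it as a grid
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=2, normalize=True), (1, 2, 0)))
plt.show()
```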
  • parameters to argparse

    1. First, get the program logic complete, with all parameters declared as plain variables at the top of the script
    2. Finally, convert those parameters to argparse format (a sketch follows this list)
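
A minimal sketch of that conversion; the flag names mirror the tutorial’s top-of-script hyperparameters:

```python
import argparse

parser = argparse.ArgumentParser(description="DCGAN training")
# illustrative flags mirroring the tutorial's top-of-script variables
parser.add_argument("--dataroot", default="data/celeba", help="path to the dataset root")
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--nz", type=int, default=100, help="size of the latent vector z")
parser.add_argument("--num-epochs", type=int, default=5)
parser.add_argument("--lr", type=float, default=0.0002)
parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1")
args = parser.parse_args()

print(args.nz, args.lr)  # use args.<name> wherever the plain variable was used before
```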
  • training
    Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow’s paper, while abiding by some of the best practices shown in ganhacks.

    • Discriminator:
      “update the discriminator by ascending its stochastic gradient”; the all-real and all-fake samples are processed as separate mini-batches
    1. Practically, we want to maximize log(D(x))+log(1−D(G(z))).
    2. First, we will construct a batch of real samples from the training set, forward pass through D, calculate the loss (log(D(x))), then calculate the gradients in a backward pass.
    3. Secondly, we will construct a batch of fake samples with the current generator, forward pass this batch through D, calculate the loss (log(1−D(G(z)))), and accumulate the gradients with a backward pass.
    4. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator’s optimizer.
    • Generator:
    1. train the Generator by minimizing log(1−D(G(z))) in an effort to generate better fakes. As mentioned, this was shown by Goodfellow to not provide sufficient gradients, especially early in the learning process.
    2. As a fix, we instead wish to maximize log(D(G(z))). In the code we accomplish this by: classifying the Generator output from Part 1 with the Discriminator, computing G’s loss using real labels as GT, computing G’s gradients in a backward pass, and finally updating G’s parameters with an optimizer step. It may seem counter-intuitive to use the real labels as GT labels for the loss function, but this allows us to use the log(x) part of the BCELoss (rather than the log(1−x) part) which is exactly what we want.
    • About code
    1. track parameters list
    2. use enumerate to iterate over the dataloader in batches
    3. zero the gradients, move the data to CUDA, feed it through the network, build the label vector,
      compute the loss with criterion, backpropagate, track D’s mean output on the fake and real batches, step the optimizer,
      record the losses, and periodically display generated images
    4. backward order: errD_real -> errD_fake (fake.detach()) -> errG (the full loop is sketched below)
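
A condensed version of the tutorial’s training loop, assuming the netG, netD, optimizers, dataloader, device, and nz defined in the sketches above:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
real_label, fake_label = 1.0, 0.0
num_epochs = 5
G_losses, D_losses = [], []

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # (1) update D: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()
        real = data[0].to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, device=device)
        output = netD(real).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()                    # gradients from the all-real batch

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)   # detach so no gradients flow into G here
        errD_fake = criterion(output, label)
        errD_fake.backward()                    # accumulates with the all-real gradients
        optimizerD.step()

        # (2) update G: maximize log(D(G(z))) by using real labels as GT
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)            # no detach: gradients must reach G
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append((errD_real + errD_fake).item())
        # periodically: save vutils.make_grid(netG(fixed_noise)) into img_list for later display
```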
  • Visualization

    1. First, we will see how D and G’s losses changed during training.
      plt.figure, title, plot, xlabel, ylabel, legend, show
      in theory, as G improves, D’s output on both real and fake images drifts toward 0.5, so the two loss curves level off rather than dropping to 0
    2. Second, we will visualize G’s output on the fixed_noise batch for every epoch.
      import matplotlib.animation as animation; animation.ArtistAnimation
      from IPython.display import HTML; HTML(ani.to_jshtml()) for inline animation in Jupyter
    3. And third, we will look at a batch of real data next to a batch of fake data from G.
      real_batch = next(iter(dataloader))
      plt.figure(figsize=(15,15)) -> plt.subplot(1,2,1) -> plt.axis('off'), plt.title('Real Images') / plt.title('Fake Images') -> plt.imshow -> plt.show
      notice: np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1,2,0))
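
Sketches of the three visualizations, assuming G_losses/D_losses from the training loop above and an img_list of make_grid outputs saved on fixed_noise during training (img_list and fixed_noise come from the tutorial, not from the sketches above):

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torchvision.utils as vutils
from IPython.display import HTML

# 1. loss curves over training iterations
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("loss")
plt.legend()
plt.show()

# 2. inline animation of G's fixed_noise outputs (displays in Jupyter)
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(g, (1, 2, 0)), animated=True)] for g in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())

# 3. a batch of real images next to the last batch of fakes
real_batch = next(iter(dataloader))
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()
```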
  • Where to Go Next?
    We have reached the end of our journey, but there are several places you could go from here. You could:

    1. Train for longer to see how good the results get
    2. Modify this model to take a different dataset and possibly change the size of the images and the model architecture
    3. Check out some other cool GAN projects here
    4. Create GANs that generate music