【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)

简述

之前认真学习了网上的一份,代码做了很详细的笔记。
【Gans入门】Pytorch实现Gans代码详解【70+代码】

但是上面的任务只是画一条在一定区间下的曲线。
这里对这个进行迁移,到可以进行图像的生成。

图像的很多数据都没有,但是突然想到在sklearn上的digits是一个非常简单的图片。
这里我想到之前的一份笔记
sklearn学习(一)

这里会使用sklearn自带的小数据来做训练
目标是让神经网络自己学会生成数字。

任务描述

为了让神经网络操作更简单。这里的输入数据只会选择特定数值的数字图片数据。然后丢给对抗生成神经网络学习。让其中的生成器学会如何生成手写数字。

下面是选择用数值1的生成过程

其实可以发现其实是有点这样的感觉了。

在这里插入图片描述

下面的这个是让它学习数字0的效果

可能是由于数字0的细节更粗糙一点,所以,可以发现,我们认为这个0生成的更好。(数字1和数字4其实是有点像的,所以会有点问题,还有这是因为图片像素有点低

在这里插入图片描述

代码详解

导入包

  • torch,numpy这些都是数据处理过程中需要的包
  • matplotlib为了画图
  • sklearn主要是为了它本身带的数据
  • random主要是为了选择标准数据更具有随机性
  • os,shutil,imageio这三个库是为了画出gif动态图
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio

创建临时文件夹

PNGFILE = './png/'
if not os.path.exists(PNGFILE):
    os.mkdir(PNGFILE)
else:
    shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

这里会创建一个临时的文件夹png,会把中途生成的那些图片都存在这,然后我就可以用这些png来生成gif文件

模型参数

  • BATCH_SIZE这个参数表示每次用多少的数据来进行考量。(数值多的话模型进化的会稍微快点)
  • LR_GLR_D表示两个模型的学习率
  • N_IDEAS:启发式因子(生成函数的初始层的节点数)。因为我们要操作的节点数量会特别大(特别是图像问题,但是如果输入节点过于大的话,会需要大量的计算资源。所以用小一点的这个基本够用就行了)
  • target_num :表示的是想要生成的数字。由于数据集中只有(0到9)所以,这里也只能取0到9。
  • image_max表示图片像素点的最大值,这个一开始我用到了,但是后来我修改了代码之后,就用不到了。
  • ART_COMPONENTS:像素点数量(其实本质上跟前一个版本的参考节点数都是一样的)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Number

digits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas

标准数据

这个函数本质上,这个区间上选BATCH_SIZE个标准数据。
但是,random.sample只能输入的是list所以需要先把data转成list,但是转出来的list又不能直接变成torch中的Tensor,这里需要再转成ndarray,之后再转成Tensor,但是要注意在后面加一个.float()函数的操作。

def artist_works():  # painting from the famous artist (real target)
    return torch.from_numpy(np.array(random.sample(list(data),
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gc.collect()

公众号“肥宅Sean”欢迎关注

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值