基于pytorch的dcgan代码实现,进行简易图像数据生成

本文介绍如何使用DCGAN在资源受限的个人电脑上,通过numpy生成1维图像数据,仅用50轮训练就实现肉眼难辨的图像生成。详细步骤包括数据集创建、网络构建、参数设置和训练过程,最终展示了生成的真假图像对比。
简易实现使用dcgan进行图像数据生成

前言:个人电脑算力有限(2G现显存GPU),现使用numpy自生成1维图像数据,dcgan生成对抗网络训练50轮(花费不到20min),生成fake image已肉眼难分。

生成结果如下

在这里插入图片描述

  1. 导入

    import torchvision
    import torchvision.transforms as transform
    from torchvision.datasets import ImageFolder
    import torchvision.utils as vutils
    
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torch.nn import BCELoss
    from torch.optim import Adam,SGD
    
    
    import matplotlib.pyplot as plt
    from matplotlib import animation # 用于生成gif图像
    import numpy as np
    
  2. 训练参数设置

    # 参数设置
    images_num = 3000 # 生成训练图像数量
    image_size = 32 # 生成图像尺寸
    image_channel = 1 # 生成图像通道数
    
    batch_size = 128 # 训练数据批次
    noise_size = 100 # 随机噪声分布向量长度,用于生成fake image
    
    lr = 0.0002 # 学习率
    epochs = 50 # 训练轮次
    device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
  3. 批量生成图像训练集

    # 1 生成简易图像数据集
    def gen_images():
        
        image = np.ones(shape=(image_size,image_size))-np.random.rand(image_size,image_size)/10
        min_x = np.random.choice(range(3,8))
        max_x = np.random.choice(range(25,30))
        values = np.random.rand(1)/10
        for i in range(min_x,max_x):
            row =i;col= int(np.power((i/max_x-0.55)*2,2)*max_x)+3
            image[col,row]=values;image[col+2,row]=values;image[col+2,row]=values
        return image
    
    # 创建图像数据集
    class Images:
        
        def __init__(self):
            self.len_ = images_num
        
        def __getitem__(self,index):
            return torch.from_numpy(gen_images()).unsqueeze(dim=0).float(),''
        
        def __len__(self):
            return self.len_
    
    # 生产批数据图像  
    images = Images()
    dataloader = DataLoader(dataset=images,batch_size=batch_size,shuffle=True,drop_last=True)
    
    # 生成图像显示
    real_batch = next(iter(dataloader))
    fig,axs = plt.subplots()
    fig.set_size_inches(w=12,h=8)
    axs.imshow(vutils.make_grid(tensor=real_batch[0][:64],nrow=8,padding
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值