GAN 对抗网络
GAN(Generative Adversarial Network)对抗网络指的是神经网络中包括两个子网络,一个用于生成信息,一个用于验证信息。下面的例子是生成图片的对抗网络,一个网络用于生成图片,另一个网络用于验证。G 网络用于生成图片,不断的学习并生成更接近于训练数据的图像,D 网络用于鉴别图片,通过学习更准确的识别出图片的真假,最终通过学习让 G 网络能够生成高质量的目标图片。下面通过代码实现两种不同的 GAN,图片为自动生成手写图片,采用 MNIST数据集。
- DCGAN (Deep Convolution)深度卷积生成对抗网络
- SAGAN(Self Attention)自注意力生成对抗网络
安装依赖
本文将使用 sklearn,首先安装 sklearn。
pip install -U scikit-learn
DCGAN 深度卷积对抗网络
数据准备
采用 MNIST 数据,并只选用 7、8 两个数字,MINST 中 7、8 数字各有 200 张。
def make_datapath_list():
"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """
train_img_list = list() #保存图像文件的路径
for img_idx in range(200):
img_path = "./data/img_78/img_7_" + str(img_idx)+'.jpg'
train_img_list.append(img_path)
img_path = "./data/img_78/img_8_" + str(img_idx)+'.jpg'
train_img_list.append(img_path)
return train_img_list
class ImageTransform():
"""图像的预处理类"""
def __init__(self, mean, std):
self.data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
def __call__(self, img):
return self.data_transform(img)
class GAN_Img_Dataset(data.Dataset):
"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""
def __init__(self, file_list, transform):
self.file_list = file_list
self.transform = transform
def __len__(self):
'''返回图像的张数'''
return len(self.file_list)
def __getitem__(self, index):
'''获取经过预处理后的图像的张量格式的数据'''
img_path = self.file_list[index]
img = Image.open(img_path) #[ 高度 ][ 宽度 ] 黑白
#图像的预处理
img_transformed = self.transform(img)
return img_transformed
#创建DataLoader并确认执行结果
#创建文件列表
train_img_list=make_datapath_list()
#创建Dataset
mean = (0.5,)
std = (0.5,)
train_dataset = GAN_Img_Dataset(
file_list=train_img_list, transform=ImageTransform(mean, std))
#创建DataLoader
batch_size = 64
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True)
#确认执行结果
batch_iterator = iter(train_dataloader) #转换为迭代器
imges = next(batch_iterator) #取出位于第一位的元素
print(imges.size()) # torch.Size([64, 1, 64, 64])
生成网络实现
需要根据输入的随机数生成图像,对数据的维度进行放大,并增加维度中的元素数量,通过 nn.ConvTranspose2d 转置卷积进行实现。转置卷积是卷积的反向操作,卷积输出特征通常比输入数据小,反向卷积输出比输入大,可以看做数据放大操作。
# 导入软件包
import random
import math
import time
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)
class Generator(nn.Module):
def __init__(self, z_dim=20, image_size=64):
super(Generator, self).__init__()
self.layer1 = nn.Sequential(
nn.ConvTranspose2d(z_dim, image_size * 8,
kernel_size=4, stride=1),
nn.BatchNorm2d(image_size * 8),
nn.ReLU(inplace=True))
self.layer2 = nn.Sequential(
nn.ConvTranspose2d(image_size * 8, image_size * 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(image_size * 4),
nn.ReLU(inplace=True))
self.layer3 = nn.Sequential(
nn.ConvTranspose2d(image_size * 4, image_size * 2,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(image_size * 2),
nn.ReLU(inplace=True))
self.layer4 = nn.Sequential(
nn.ConvTranspose2d(image_size * 2, image_size,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(image_size),
nn.ReLU(inplace=True))
self.last = nn.Sequential(
nn.ConvTranspose2d(image_size, 1, kernel_size=4,
stride=2, padding=1),
nn.Tanh())
# 注意:因为是黑白图像,所以只有一个输出通道
def forward(self, z):
out = self.layer1(z)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.last(out)
return out
根据输入,生成图片
#动作确认
import matplotlib.pyplot as plt
%matplotlib inline
G = Generator(z_dim=20, image_size=64)
# 输入的随机数
input_z = torch.randn(1, 20)
# 将张量尺寸变形为(1,20,1,1)
input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
#输出假图像
fake_images = G(input_z)
img_transformed = fake_images[0][0].detach().numpy()
plt.imshow(img_transformed, 'gray')
plt.show()
没有经过学习生成的图片,目标是通过学习生成手写数字的效果。


最低0.47元/天 解锁文章
3177

被折叠的 条评论
为什么被折叠?



