G3 - 手势图像生成 CGAN入门



理论知识

CGAN(条件生成对抗网络)是在GAN(生成对抗网络)的基础上进行了一些改进。对于原始的GAN生成器而言,用来生成图像的数据是随机不可预测的,因此没有办法控制网络的输出,在实际操作中的可控性不强。

针对原始GAN无法生成具有特定属性的图像数据的问题,Mehdi Mirza等人在2014年提出了CGAN,通过给原始GAN中的生成器G和判别器D增加额外的条件,来把无监督学习的GAN转化为有监督学习的CGAN,便于网络能够在我们的掌控下更好地进行训练。

例如:我们需要生成器G生成一张没有阴影的图像,此时判别器D就需要判断生成器所生成的图像是否是一张没有阴影的图像。

CGAN的本质就是将额外的信息融入到生成器和判别器中,其中添加的信息可以是图像的类别 ,人脸表情和其他辅助信息等。

网络结构如图所示:
CGAN网络结构
从图中的网络结构可知,条件信息y作为额外的输入被引入到GAN中,与生成器中的噪声z合并作为隐含层的表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的许多研究中被证明是非常有效的,为后续的相关工作提供了积极的指导作用。

环境

  • Python 3.11
  • GTX 4090
  • Pytorch 2.1.0

步骤

环境设置

包引用

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt

创建一个全局的设备对象,和批次大小

# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 批次大小
batch_size = 128

数据准备

导入数据

transform = transforms.Compose([
	transforms.Resize(128),
	transforms.ToTensor(),
	transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset = datasets.ImageFolder(
	root='data/rps', transform=transform)

train_loader = DataLoader(dataset=train_dataset,
						batch_size=batch_size,
						shuffle=True,
						num_workers=6)

查看数据集中的数据

def show_images(images):
	"""把图像组合成一个网络,并展示"""
	plt.figure(figsize=(20, 20)
	plt.axis('off')
	plt.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
	"""在数据库中取一个批次的数据进行展示"""
	for images, _ in dl:
		show_images(images)
		break

show_batch(train_loader)

数据集展示

模型设计

首先设置一下模型输入输出 ,隐藏层的参数

# 图像的形状
image_shape = (3, 128, 128)
# 扯平后的维度
image_dim = int(np.prod(image_shape))
# 隐藏层的维度
latent_dim = 100

# 分类数量 剪刀 石头 布
n_classes = 3
# 嵌入维度
embedding_dim = 100

编写模型的初始化函数

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)

构建生成器

class Generator(nn.Module):
	def __init__(self):
		super().__init__()
		# 条件标签生成器,用来将标签映射到嵌入空间中
		self.label_conditioned_generator = nn.Sequential(
			nn.Embedding(n_classes, embedding_dim), # 使用Embedding层,将条件标签映射为稠密向量
			nn.Linear(embedding_dim, 16) # 使用线性层将稠密向量转换为更好维度
		)
		# 潜在向量生成器,用于将噪声向量映射到图像空间中
		self.latent = nn.Sequential(
			nn.Linear(latent_dim, 4*4*512), # 使用线性层将潜在向量转换为更高维度
			nn.LeakyReLU(0.2, inplace=True)
		)
		# 生成器的主要结构,将条件标签和潜在向量合并,生成图像
		self.model = nn.Sequential(
			# 反卷积层1:将合并后的向量映射为64*8*8的特征图
			nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
			nn.BatchNorm2d(64*8,<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值