GAN(生成对抗网络)是一种常用且效果出色的图像生成模型。我们使用支持条件生成的 cGAN(条件生成对抗网络)。下面介绍一个简单 cGAN 模型的构建与训练过程。
2.1 在 model 文件夹中新建 nets.py 文件
import torch
import torch.nn as nn
# 生成器类
class Generator(nn.Module):
    """Conditional DCGAN generator built from transposed convolutions.

    The class label is embedded into an ``nz``-dimensional vector, reshaped
    to a 1x1 feature map and concatenated with the latent tensor along the
    channel axis, so the first layer receives ``2 * nz`` channels.

    Args:
        nz: Number of latent channels; also the label-embedding width.
        nc: Number of channels in the generated image.
        ngf: Base width of the generator feature maps.
        num_classes: Number of distinct condition labels.
    """

    def __init__(self, nz=100, nc=3, ngf=128, num_classes=4):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, nz)

        def up_block(in_ch, out_ch, stride, padding):
            # Transposed conv -> batch norm -> in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        widths = (ngf * 8, ngf * 4, ngf * 2, ngf)
        # First stage expands the 1x1 input to 4x4 (stride 1, no padding).
        stages = up_block(2 * nz, widths[0], 1, 0)
        # Each following stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.extend(up_block(in_ch, out_ch, 2, 1))
        # Final doubling to 64x64; Tanh bounds the output to [-1, 1].
        stages.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        stages.append(nn.Tanh())
        self.main = nn.Sequential(*stages)

    def forward(self, z, labels):
        """Generate images from latent tensor ``z`` conditioned on ``labels``.

        ``z`` must be concatenable along dim 1 with the label embedding after
        two trailing singleton axes are added to it (i.e. ``(batch, nz, 1, 1)``
        when ``labels`` is a 1-D index tensor).
        """
        embedded = self.label_emb(labels)
        cond = embedded.unsqueeze(2).unsqueeze(3)
        stacked = torch.cat((z, cond), dim=1)
        return self.main(stacked)
# 判别器类
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator for 64x64 images.

    The class label is embedded into a single 64x64 plane that is appended
    to the image as one extra channel, so the first convolution receives
    ``nc + 1`` channels.

    Args:
        nc: Number of channels in the input image.
        ndf: Base width of the discriminator feature maps.
        num_classes: Number of distinct condition labels.
    """

    # Spatial size of the label plane; input images must match it so the
    # channel-wise concatenation in ``forward`` is valid.
    IMAGE_SIZE = 64

    def __init__(self, nc=3, ndf=64, num_classes=4):
        super(Discriminator, self).__init__()
        size = self.IMAGE_SIZE
        # The label occupies exactly ONE extra channel (see ``nc + 1`` below),
        # so the embedding must hold size * size values. The previous width of
        # nc * 64 * 64 made the ``view`` in ``forward`` raise for any nc != 1.
        self.label_emb = nn.Embedding(num_classes, size * size)
        self.main = nn.Sequential(
            nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # NOTE(review): with a 64x64 input the feature map here is 8x8, so
            # this 4x4 / stride-1 conv yields a 5x5 score map rather than one
            # scalar per sample. Confirm the training loop builds its BCE
            # targets with a matching shape.
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Score ``img`` as real/fake conditioned on ``labels``.

        Args:
            img: Image batch of shape ``(batch, nc, 64, 64)``.
            labels: 1-D tensor of integer class indices, one per image.

        Returns:
            Sigmoid probabilities of shape ``(batch, 1, 5, 5)``.
        """
        size = self.IMAGE_SIZE
        # Reshape each label embedding into a single conditioning plane.
        c = self.label_emb(labels).view(labels.size(0), 1, size, size)
        x = torch.cat([img, c], 1)
        return self.main(x)
2.2 新建 cGAN_net.py 文件
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
# ===========================
# Conditional DCGAN 实现
# ===========================
class cDCGAN:
def __init__(self, data_root, batch_size, device, latent_dim=100, num_classes=4):
self.device = device
self.batch_size = batch_size
self.latent_dim = latent_dim
self.num_classes = num_classes
# 数据加载器
self.train_loader = self.get_dataloader(data_root)
# 初始化生成器和判别器
self.generator = self.build_generator().to(device)
self.discriminator = self.build_discriminator().to(device)
# 初始化权重
self.generator.apply(self.weights_init)
self.discriminator.apply(self.weights_init)
# 损失函数和优化器
self.criterion = nn.BCELoss()
self.optimizer_G = Adam(self.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
self.optimizer_D = Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# 学习率调度器
self.scheduler_G = StepLR(self.optimizer_G, step_size=10, gamma=0.5) # 每10个epoch学习率减半
self.scheduler