In this post I share some of my experiments on image generation with Generative Adversarial Networks (GANs). The experiments are based on the CIFAR-100 dataset and aim to generate high-quality images with a deep learning model while exploring several key techniques and optimizations in the training process.
The full code is attached to this post as a downloadable resource for reference.
Background
A Generative Adversarial Network (GAN) is a powerful deep learning framework made up of two parts: a generator and a discriminator. The generator tries to produce images that look as real as possible, while the discriminator's job is to decide whether a given image is real or generated. During training the two networks improve through an adversarial game, until the generator is able to produce high-quality images.
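Formally, this game can be written as the standard minimax objective, where the generator $G$ maps a noise vector $z$ to an image and the discriminator $D$ scores how likely its input is to be real:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

The discriminator loss in the code below is the binary cross-entropy version of this objective (BCEWithLogitsLoss), and the generator uses the common non-saturating variant, maximizing $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$.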
For this experiment I chose the CIFAR-100 dataset, which contains color images from 100 classes, each 32x32 pixels. On top of it I carried out a series of model optimizations and technical explorations aimed at generating images of good visual quality.
Model Architecture
- Generator: the generator turns a random noise vector (latent vector) into an image. To improve the quality of the generated images, I added several techniques, including conditional batch normalization (Conditional BatchNorm2d) and self-attention (Self Attention). I also use cross-attention (Cross Attention) to fuse label information with the image features, which improves class-conditional generation.
- Discriminator: the discriminator judges whether an input image comes from the real dataset or from the generator. Residual blocks (ResBlockD) and spectral normalization (Spectral Normalization) help it learn to assess image realism more reliably during training.
Experiment Code
import os
import shutil
import torch
import time
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.models import inception_v3
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import torch.nn.utils.spectral_norm as spectral_norm  # spectral normalization wrapper
# Hyperparameters
latent_dim = 128
num_classes = 100  # number of CIFAR-100 classes
img_size = 32
channels = 3
batch_size = 128
n_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Data preprocessing
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * channels, [0.5] * channels),
])
# Load the dataset
train_dataset = datasets.CIFAR100(root='data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Self-attention layer
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x  # learnable residual connection
        return out
# Conditional batch normalization
class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes):
        super(ConditionalBatchNorm2d, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Initialize gamma to 1 and beta to 0
        self.embed.weight.data[:, :num_features].fill_(1)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)
        out = gamma * out + beta
        return out
# Residual block for the generator
class ResBlockG(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super(ResBlockG, self).__init__()
        self.bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
        self.bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        # conv1 always upsamples by 2x, so the shortcut must upsample as well
        # (and adapt the channel count), regardless of in/out channels.
        self.shortcut = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, labels):
        out = self.bn1(x, labels)
        out = self.activation(out)
        out = self.conv1(out)
        out = self.bn2(out, labels)
        out = self.activation(out)
        out = self.conv2(out)
        shortcut = self.shortcut(x)
        out += shortcut
        return out
# Residual block for the discriminator
class ResBlockD(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlockD, self).__init__()
        self.conv1 = spectral_norm(nn.Conv2d(in_channels, out_channels, 3, 1, 1))
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = spectral_norm(nn.Conv2d(out_channels, out_channels, 4, 2, 1))
        # conv2 always downsamples by 2x, so the shortcut must downsample too
        # (an identity shortcut would produce mismatched spatial sizes).
        self.shortcut = spectral_norm(nn.Conv2d(in_channels, out_channels, 4, 2, 1))

    def forward(self, x):
        out = self.conv1(x)
        out = self.activation(out)
        out = self.conv2(out)
        shortcut = self.shortcut(x)
        out += shortcut
        return out
# Cross-attention (fuses class information into the image features)
class CrossAttention(nn.Module):
    def __init__(self, in_dim):
        super(CrossAttention, self).__init__()
        # Project image features to queries Q
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        # Project the class embedding to keys K and values V
        self.key_linear = nn.Linear(in_dim, in_dim // 8)
        self.value_linear = nn.Linear(in_dim, in_dim)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, class_emb):
        # x: B, C, H, W
        # class_emb: B, C (same C as x so the two can be fused)
        B, C, H, W = x.size()
        Q = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B, HW, C'
        K = self.key_linear(class_emb).unsqueeze(1)                 # B, 1, C'
        # Attention weights between spatial positions and the class token
        energy = torch.bmm(Q, K.transpose(1, 2))                    # B, HW, 1
        attention = self.softmax(energy)
        V = self.value_linear(class_emb).unsqueeze(1)               # B, 1, C
        out = torch.bmm(attention, V)                               # B, HW, C
        out = out.permute(0, 2, 1).view(B, C, H, W)
        out = self.gamma * out + x
        return out
# Generator
class Generator(nn.Module):
    def __init__(self, latent_dim, num_classes, channels):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        img_size = 32                    # fixed output size
        self.init_size = img_size // 16  # init_size = 2
        self.l1 = nn.Linear(latent_dim, 512 * self.init_size ** 2)
        self.res_blocks = nn.ModuleList([
            ResBlockG(512, 512, num_classes),
            ResBlockG(512, 256, num_classes),
            ResBlockG(256, 128, num_classes),
            ResBlockG(128, 64, num_classes)
        ])
        self.attention = SelfAttention(64)
        # Class embedding used by cross-attention (matched to 64 channels)
        self.class_embed_ca = nn.Embedding(num_classes, 64)
        self.cross_attention = CrossAttention(64)
        # Fusion conv: upsampled intermediate features concatenated with the final features
        self.fusion_conv = nn.Conv2d(256 + 64, 64, 1, 1, 0)
        self.bn = ConditionalBatchNorm2d(64, num_classes)
        self.activation = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(64, channels, 3, 1, 1)

    def forward(self, noise, labels):
        out = self.l1(noise)
        out = out.view(out.size(0), 512, self.init_size, self.init_size)
        # Multi-scale feature generation
        out1 = self.res_blocks[0](out, labels)   # 512 channels, 4x4
        out2 = self.res_blocks[1](out1, labels)  # 256 channels, 8x8
        out3 = self.res_blocks[2](out2, labels)  # 128 channels, 16x16
        out4 = self.res_blocks[3](out3, labels)  # 64 channels, 32x32
        out = out4
        # Self-attention
        out = self.attention(out)
        # Cross-attention to fuse class features
        class_emb = self.class_embed_ca(labels)  # B, 64
        out = self.cross_attention(out, class_emb)
        # Feature fusion: upsample out2 (8x8) to the size of out (32x32), then concatenate
        out2_upsampled = F.interpolate(out2, size=out.shape[2:], mode='nearest')
        fused = torch.cat([out2_upsampled, out], dim=1)  # B, (256+64), 32, 32
        out = self.fusion_conv(fused)                    # B, 64, 32, 32
        out = self.bn(out, labels)
        out = self.activation(out)
        img = torch.tanh(self.conv_out(out))
        return img
# Discriminator
class Discriminator(nn.Module):
    def __init__(self, num_classes, channels):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.initial = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, 64, 3, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.res_blocks = nn.ModuleList([
            ResBlockD(64, 128),
            ResBlockD(128, 256),
            ResBlockD(256, 512)
        ])
        self.attention = SelfAttention(512)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = spectral_norm(nn.Linear(512, 1))
        self.embed = nn.Embedding(num_classes, 512)

    def forward(self, img, labels):
        out = self.initial(img)
        for res_block in self.res_blocks:
            out = res_block(out)
        out = self.attention(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        validity = self.fc(out)
        # Projection-style conditioning: add the inner product between the pooled
        # features and the label embedding to the unconditional score
        embed = self.embed(labels)
        prod = torch.sum(out * embed, dim=1, keepdim=True)
        return validity + prod
# Initialize models
generator = Generator(latent_dim, num_classes, channels).to(device)
discriminator = Discriminator(num_classes, channels).to(device)
# Loss function
adversarial_loss = nn.BCEWithLogitsLoss()
# Optimizers
optimizer_G = optim.RMSprop(generator.parameters(), lr=0.0002, weight_decay=1e-3)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=0.0002, weight_decay=1e-3)
# MultiStepLR schedulers
milestones = [21, 40, 70]  # epochs at which the learning rate is reduced
gamma = 0.5                # each reduction multiplies the learning rate by 0.5
scheduler_G = optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=milestones, gamma=gamma)
scheduler_D = optim.lr_scheduler.MultiStepLR(optimizer_D, milestones=milestones, gamma=gamma)
# Folders used for FID computation
fid_real_images_path = 'fid_images/realbiggan_images'            # real images
fid_generated_images_path = 'fid_images/generatedbiggan_images'  # generated images
model_save_path = 'two_model.pth'  # path of the best checkpoint
# Delete and recreate the folders to clear any old images
for path in [fid_real_images_path, fid_generated_images_path]:
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
# Load a pretrained InceptionV3 model (used for the Inception Score)
inception_model = inception_v3(weights="DEFAULT").to(device)
inception_model.eval()
# Inception Score computation
def calculate_inception_score(images, batch_size=32, splits=10):
    images = F.interpolate(images, size=(299, 299), mode="bilinear", align_corners=False)
    preds = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size].to(device)
        with torch.no_grad():
            pred = inception_model(batch)
        preds.append(F.softmax(pred, dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    split_scores = []
    for k in range(splits):
        part = preds[k * (len(preds) // splits): (k + 1) * (len(preds) // splits)]
        py = np.mean(part, axis=0)
        # Per-sample KL divergence between p(y|x) and the marginal p(y)
        scores = part * (np.log(part + 1e-6) - np.log(py + 1e-6))
        split_scores.append(np.exp(np.mean(np.sum(scores, axis=1))))
    return np.mean(split_scores), np.std(split_scores)
# Track the best Inception Score and the epoch at which it was reached
best_is = 0
best_epoch = 0
# Training loop
start_time = time.time()
previous_lr_G = optimizer_G.param_groups[0]['lr']
previous_lr_D = optimizer_D.param_groups[0]['lr']
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        batch_size_current = imgs.size(0)
        valid = torch.ones(batch_size_current, 1, device=device)
        fake = torch.zeros(batch_size_current, 1, device=device)
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        # ---------------------
        #  Train the discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real images
        real_pred = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(real_pred, valid)
        # Generated images
        z = torch.randn(batch_size_current, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size_current,), device=device)
        fake_imgs = generator(z, gen_labels)
        fake_pred = discriminator(fake_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(fake_pred, fake)
        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        # ---------------------
        #  Train the generator
        # ---------------------
        optimizer_G.zero_grad()
        validity = discriminator(fake_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()
    # Step the learning-rate schedulers once per epoch
    scheduler_G.step()
    scheduler_D.step()
    current_lr_G = scheduler_G.get_last_lr()[0]
    current_lr_D = scheduler_D.get_last_lr()[0]
    if current_lr_G != previous_lr_G or current_lr_D != previous_lr_D:
        print(f"Epoch {epoch+1}/{n_epochs} | LR Changed - Generator: {current_lr_G:.6f}, Discriminator: {current_lr_D:.6f}")
        previous_lr_G = current_lr_G
        previous_lr_D = current_lr_D
    # Save real images from the last batch for FID computation
    with torch.no_grad():
        real_imgs_count = len(real_imgs)
        for idx in range(real_imgs_count):
            real_save_path = os.path.join(fid_real_images_path, f"real_img_epoch{epoch+1}_{idx}.png")
            vutils.save_image(real_imgs[idx], real_save_path, normalize=True)
    # Save generated images for FID computation
    with torch.no_grad():
        for idx in range(real_imgs_count):
            gen_save_path = os.path.join(fid_generated_images_path, f"gen_img_epoch{epoch+1}_{idx}.png")
            vutils.save_image(fake_imgs[idx], gen_save_path, normalize=True)
    # Compute the current Inception Score on 1000 freshly generated images
    with torch.no_grad():
        num_eval_samples = 1000
        eval_batch_size = 100
        eval_imgs = []
        for _ in range(num_eval_samples // eval_batch_size):
            z = torch.randn(eval_batch_size, latent_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (eval_batch_size,), device=device)
            eval_batch_imgs = generator(z, gen_labels)
            eval_imgs.append(eval_batch_imgs.cpu())
        eval_imgs = torch.cat(eval_imgs, dim=0)
        mean_is, std_is = calculate_inception_score(eval_imgs, batch_size=eval_batch_size, splits=10)
    print(f"Epoch {epoch+1}/{n_epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | IS: {mean_is:.4f} ± {std_is:.4f}")
    # Save the model whenever the Inception Score improves
    if mean_is > best_is:
        best_is = mean_is
        best_epoch = epoch + 1
        torch.save({
            'epoch': best_epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'best_is': best_is,
        }, model_save_path)
        print(f"Best model saved from epoch: {best_epoch} with Inception Score: {best_is:.4f}")
end_time = time.time()
total_time = end_time - start_time
hours = int(total_time // 3600)
minutes = int((total_time % 3600) // 60)
seconds = int(total_time % 60)
print(f"Total training time: {hours}h {minutes}m {seconds}s")
Training Process
During training I used the RMSprop optimizer, with a MultiStepLR learning-rate schedule for both the generator and the discriminator, to keep training stable and speed up convergence. Training ran for 100 epochs, and at the end of every epoch I computed the Inception Score (IS) on 1,000 freshly generated images to measure the quality of the generated samples.
- Inception Score (IS): evaluates the diversity and clarity of the generated images with a pretrained Inception model (see the formula below).
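Concretely, the Inception Score is the exponentiated expected KL divergence between the class distribution $p(y\mid x)$ that InceptionV3 predicts for a single generated image and the marginal class distribution $p(y)$ over all generated images, which is what the calculate_inception_score function above computes on the softmax outputs:

$$\mathrm{IS} = \exp\Big(\mathbb{E}_{x \sim p_g}\big[D_{\mathrm{KL}}\big(p(y\mid x)\,\Vert\,p(y)\big)\big]\Big)$$

Higher is better: the score rewards images that are classified confidently (clarity) while the overall class distribution stays broad (diversity).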
Experimental Results
During training the generator and discriminator losses gradually decreased and the Inception Score rose, indicating that the quality of the generated images kept improving. Over the 100 training epochs the best Inception Score reached 4.4049, and the corresponding model was saved.
| Epoch | D Loss | G Loss | IS Score |
|---|---|---|---|
| 88 | 0.5526 | 1.3599 | 4.4049 |
| 89 | 0.5867 | 1.3732 | 4.1642 |
| 90 | 0.5777 | 1.3413 | 3.9381 |
As the table shows, the generator loss (G Loss) still fluctuates from epoch to epoch, but the Inception Score reaches its peak of 4.4049 at epoch 88, which is the checkpoint that was kept as the best model.
- FID Score: 15.8565
- Inception Score: 6.0625 ± 0.7847
- Intra-FID: 51.2626
- Training time: 4 hours, 36 minutes, 36 seconds (on a single RTX 3090)
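The FID values are computed offline from the real and generated image folders that the training loop fills. FID compares the mean and covariance of InceptionV3 features of the two image sets, $\mathrm{FID} = \lVert\mu_r-\mu_g\rVert^2 + \mathrm{Tr}\big(\Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2}\big)$, and Intra-FID is commonly the per-class FID averaged over all classes. As a minimal sketch, assuming the third-party pytorch-fid package is installed (its API may differ slightly between versions), the score can be obtained from the folders defined above:

# Minimal sketch: compute FID from the two folders filled during training.
# Assumes the external pytorch-fid package (pip install pytorch-fid); the paths
# match fid_real_images_path / fid_generated_images_path from the training script.
import torch
from pytorch_fid import fid_score

fid = fid_score.calculate_fid_given_paths(
    ['fid_images/realbiggan_images', 'fid_images/generatedbiggan_images'],
    batch_size=50,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dims=2048,  # dimensionality of the InceptionV3 pool3 features
)
print(f"FID: {fid:.4f}")

The same number can also be obtained from the command line with `python -m pytorch_fid fid_images/realbiggan_images fid_images/generatedbiggan_images`.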
Model Optimization and Challenges
During training, the following techniques and adjustments were crucial to improving the model's performance:
- Conditional batch normalization (Conditional BatchNorm2d): lets the generator modulate its features according to the label, so it produces images of the requested class instead of blurry, class-agnostic outputs.
- Self-attention (Self Attention): helps the generator focus on the important regions of the image, improving detail and overall quality.
- Cross-attention (Cross Attention): by interacting with the label information, the generator better captures the characteristics of each class, making the generated images more class-consistent.
Even so, some challenges remained. The adversarial game between the generator and the discriminator can make training unstable, especially when the learning rate becomes small, so in the later stages I used the MultiStepLR scheduler to adjust the learning rate dynamically and improve training stability.
Reproducing the Results
import random
import numpy as np
import torch

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed all GPUs when running on multiple devices
    torch.cuda.manual_seed_all(seed)
    # Make cuDNN deterministic (may reduce performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the random seed at the very beginning of the script
set_random_seed(42)
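To go with the seeding code, here is a minimal sketch of how the best checkpoint saved during training (two_model.pth) might be reloaded to sample class-conditional images. It assumes the Generator class and the hyperparameters latent_dim, num_classes, channels and device from the training script are in scope; the class index 3 is only an example.

# Minimal sketch: reload the best checkpoint and sample a grid of images.
# Assumes Generator, latent_dim, num_classes, channels and device are defined as above.
import torch
import torchvision.utils as vutils

checkpoint = torch.load('two_model.pth', map_location=device)
generator = Generator(latent_dim, num_classes, channels).to(device)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    labels = torch.full((64,), 3, dtype=torch.long, device=device)  # example: class index 3
    samples = generator(z, labels)
vutils.save_image(samples, 'samples_class3.png', normalize=True, nrow=8)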
Future Work
Although this experiment already achieves fairly good generation results, there is still room for improvement. For example, the generator architecture could be optimized further, more self-supervised learning methods could be explored, or different loss functions could be tried to push image quality higher. More elaborate architectures such as BigGAN could also be used to improve the resolution and detail of the generated images.
Closing Remarks
This experiment demonstrates the potential of GANs for image generation, and in particular their ability to produce high-quality images on the CIFAR-100 dataset. With a carefully designed model and the optimization techniques described above, the generated images come reasonably close to the real data and achieve respectable evaluation scores. If you are also interested in GANs, feel free to use this post as a reference and experiment with your own improvements.
Thanks for reading! If you have any questions or ideas, feel free to share them in the comments.