Introduction
An improved SRGAN model with a deeper residual structure: it performs SR on low-resolution images so that, when an image is upscaled to a larger size, its details are enhanced and restored.
Paper title: Improved Generative Adversarial Network for Generating High-Resolution Images from Low-Resolution Images
Conference: 2023 13th International Conference on Cloud Computing, Data Science & Engineering (Confluence)
Abstract: The process of generating a high-resolution (HR) image from a low-resolution (LR) image is called super-resolution (SR). In this paper, we attempt to perform SR using deep learning techniques. For better performance in medical imaging, forensics, pattern recognition, satellite imaging, surveillance, and similar fields, specific regions of interest in an image need to be zoomed into, which demands high resolution. We propose ISRGAN (Improved Super-Resolution Generative Adversarial Network), an improved version of SRGAN (Super-Resolution Generative Adversarial Network) for image SR. Images from the DIV2K dataset were used to train the network. The dataset consists of 800 distinct images, which were resized to 32x32 and 128x128 pixels to form the LR and HR images, respectively. A generative adversarial network, built on the idea of adversarial training, consists of two parts: a discriminator network and a generator network. The generator produces HR images from LR images, while the discriminator classifies generated images as fake or real. The proposed model produced decent results, and our final super-resolved outputs show that the proposed ISRGAN model generates images with enhanced detail.
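The LR/HR pair preparation described in the abstract can be reproduced with a short script. Below is a minimal sketch, assuming a folder of DIV2K-style source images; the directory names, function name, and target sizes are placeholders to adapt:

import os
from PIL import Image

def make_lr_hr_pairs(src_dir, lr_dir, hr_dir, lr_size=32, hr_size=128):
    # Resize each source image twice to form an LR/HR training pair
    os.makedirs(lr_dir, exist_ok=True)
    os.makedirs(hr_dir, exist_ok=True)
    for name in sorted(os.listdir(src_dir)):
        img = Image.open(os.path.join(src_dir, name)).convert('RGB')
        img.resize((hr_size, hr_size), Image.BICUBIC).save(os.path.join(hr_dir, name))
        img.resize((lr_size, lr_size), Image.BICUBIC).save(os.path.join(lr_dir, name))

# make_lr_hr_pairs("DIV2K_train_HR", "LR_images", "HR_images")  # paths are placeholders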
Model Architecture
An obvious characteristic of the model is its residual blocks and very deep structure, which gives it a very large parameter count; many low-end PCs and even servers simply cannot run the full task! The proposed work is a modified and enhanced version of SRGAN, aimed at improving the model's effectiveness. SRGAN uses a deep residual network (ResNet) with skip connections, together with feature maps from a pre-trained VGG19 network (the ContentLoss class below implements the latter).
High-Resolution Generator Architecture
Compared with the 16 residual blocks used in SRGAN, ISRGAN's generator network uses 20 identical residual blocks. Each residual block adds 2 extra convolutional layers, plus 2 batch-normalization layers and 2 parametric ReLU (PReLU) layers.
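The paper's wording leaves the exact layer order open; one possible reading, sketched below under that assumption (the class name and layout are mine, not taken from the paper), extends SRGAN's conv-BN-PReLU block with the extra layers:

import torch.nn as nn

class ISRGANGenResBlock(nn.Module):
    # Assumed layout: SRGAN's block (2 conv, 2 BN, 1 PReLU) plus the described
    # 2 extra convs, 2 extra BNs and 2 extra PReLUs, with the usual skip connection
    def __init__(self, channels=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1), nn.BatchNorm2d(channels), nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1), nn.BatchNorm2d(channels), nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1), nn.BatchNorm2d(channels), nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1), nn.BatchNorm2d(channels),
        )
    def forward(self, x):
        return x + self.body(x)

# 20 distinct blocks, versus SRGAN's 16:
# body = nn.Sequential(*[ISRGANGenResBlock() for _ in range(20)])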
Discriminator Architecture
SRGAN's discriminator network has no residual blocks; 5 residual blocks have been added to ISRGAN's discriminator network. ISRGAN's discriminator uses leaky ReLU activations in its convolutional discriminator blocks and parametric ReLU in the residual blocks. It ends with a dense layer with a sigmoid activation to classify an image as real or generated.
Training on Your Own Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import vgg19
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DResBlock(nn.Module):
def __init__(self):
super(DResBlock, self).__init__()
self.resblock = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
# nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(64),
# nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(64)
)
def forward(self, x):
x = self.resblock(x)
return x
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
self.lrelu = nn.LeakyReLU()
self.prelu = nn.PReLU()
self.resblok1 = DResBlock()
# self.resblok2 = DResBlock()
# self.resblok3 = DResBlock()
# self.resblok4 = DResBlock()
# self.resblok5 = DResBlock()
self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.bn64 = nn.BatchNorm2d(64)
self.block1 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU()
)
# self.block2 = nn.Sequential(
# nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
# nn.BatchNorm2d(128),
# nn.LeakyReLU()
# )
# self.block3 = nn.Sequential(
# nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(256),
# nn.LeakyReLU()
# )
# self.block4 = nn.Sequential(
# nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
# nn.BatchNorm2d(256),
# nn.LeakyReLU()
# )
# self.block5 = nn.Sequential(
# nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(512),
# nn.LeakyReLU()
# )
# self.block6 = nn.Sequential(
# nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1),
# nn.BatchNorm2d(512),
# nn.LeakyReLU()
# )
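        # fc1 expects 128 * 32 * 32 = 131072 features: with 64x64 HR inputs,
        # conv1_2 (stride 2) halves the spatial size to 32x32 and block1 outputs
        # 128 channels. Recompute this value if the input size or conv stack changes.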
self.fc1 = nn.Linear(131072, 1024)
self.lrelu_fc = nn.LeakyReLU()
self.fc2 = nn.Linear(1024, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.lrelu(self.bn64(self.conv1_2(self.lrelu(self.conv1(x)))))
out = self.resblok1(x)
out1 = out + x
# x = self.resblok2(out1)
# out2 = x + out1
# x = self.resblok2(out2)
# out3 = x + out2
# x = self.resblok2(out3)
# out4 = x + out3
# x = self.resblok2(out4)
# out5 = x + out4
# out = self.bn64(self.conv2(out5))
# out = out + out5
out = self.block1(out1)
# out = self.block2(out)
# out = self.block3(out)
# out = self.block4(out)
# out = self.block5(out)
# out = self.block6(out)
out = out.view(out.size(0), -1)
out = self.fc1(out)
out = self.lrelu_fc(out)
out = self.fc2(out)
out = self.sigmoid(out)
return out
class GResBlock(nn.Module):
def __init__(self):
super(GResBlock, self).__init__()
self.resblock = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
# nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(64),
# nn.PReLU(),
# nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
# nn.BatchNorm2d(64),
# nn.PReLU(),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64)
)
def forward(self, x):
x = self.resblock(x)
return x
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4)
self.prelu = nn.PReLU()
self.block = GResBlock()
self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv5 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=9, stride=1, padding=4)
self.bn = nn.BatchNorm2d(64)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
def forward(self, x):
x = self.conv1(x)
x = self.prelu(x)
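        # NOTE: range(1) applies the residual block once; raising this count would
        # reuse the SAME block (shared weights). For 20 distinct blocks as in the
        # paper, build them with an nn.ModuleList instead.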
for i in range(1):
x = self.block(x) + x
x = self.bn(self.conv2(x))
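        # NOTE: the original SRGAN adds a global skip connection here, summing the
        # features saved from before the residual blocks; this simplified version
        # omits it.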
x = self.prelu(self.pixel_shuffle(self.conv3(x)))
x = self.prelu(self.pixel_shuffle(self.conv4(x)))
x = self.conv5(x)
return x
class ContentLoss(nn.Module):
def __init__(self):
super(ContentLoss, self).__init__()
vgg = vgg19(pretrained=True).features
self.feature_extractor = nn.Sequential(*list(vgg.children())[:35])
for param in self.feature_extractor.parameters():
param.requires_grad = False
self.feature_extractor = self.feature_extractor.to(device)
def forward(self, sr, hr):
if sr.dim() != 4 or hr.dim() != 4:
raise ValueError(f"Expected 4D input, but got sr: {sr.dim()}D, hr: {hr.dim()}D")
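        # NOTE: the transforms below normalize images to [-1, 1], whereas torchvision's
        # pretrained VGG19 expects ImageNet mean/std normalization; strictly, inputs
        # should be re-normalized before feature extraction.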
sr_features = self.feature_extractor(sr)
hr_features = self.feature_extractor(hr)
return F.mse_loss(sr_features, hr_features)
# Define the adversarial loss
class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
def forward(self, sr_output):
return F.binary_cross_entropy(sr_output, torch.ones_like(sr_output))
# Define the total loss
class TotalLoss(nn.Module):
def __init__(self, content_weight=1e-3, adversarial_weight=1e-3):
super(TotalLoss, self).__init__()
self.content_loss = ContentLoss()
self.adversarial_loss = AdversarialLoss()
self.content_weight = content_weight
self.adversarial_weight = adversarial_weight
def forward(self, sr, hr, sr_output):
content_loss = self.content_loss(sr, hr)
adversarial_loss = self.adversarial_loss(sr_output)
total_loss = self.content_weight * content_loss + self.adversarial_weight * adversarial_loss
return total_loss, content_loss, adversarial_loss
# Custom dataset class
class SRDataset(Dataset):
def __init__(self, lr_dir, hr_dir, lr_transform=None, hr_transform=None):
self.lr_dir = lr_dir
self.hr_dir = hr_dir
self.lr_transform = lr_transform
self.hr_transform = hr_transform
self.lr_files = sorted(os.listdir(lr_dir))
self.hr_files = sorted(os.listdir(hr_dir))
def __len__(self):
return len(self.lr_files)
def __getitem__(self, idx):
lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
hr_path = os.path.join(self.hr_dir, self.hr_files[idx])
lr_img = Image.open(lr_path).convert('RGB')
hr_img = Image.open(hr_path).convert('RGB')
if self.lr_transform:
lr_img = self.lr_transform(lr_img)
if self.hr_transform:
hr_img = self.hr_transform(hr_img)
return lr_img, hr_img
transform = transforms.Compose([
transforms.Resize(16),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform2 = transforms.Compose([
transforms.Resize(64),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
lr_dir = r"low_resolution_images"   # replace with your LR image directory
hr_dir = r"high_resolution_images"  # replace with your HR image directory
train_dataset = SRDataset(lr_dir, hr_dir, lr_transform=transform, hr_transform=transform2)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
generator = Generator().to(device)
discriminator = Discriminator().to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))
total_loss = TotalLoss(content_weight=1e-3, adversarial_weight=1e-3)
num_epochs = 100
print_every = 10
for epoch in range(num_epochs):
generator.train()
discriminator.train()
for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
lr_imgs = lr_imgs.to(device)
hr_imgs = hr_imgs.to(device)
g_optimizer.zero_grad()
sr_imgs = generator(lr_imgs)
sr_output = discriminator(sr_imgs)
g_loss, content_loss, adversarial_loss = total_loss(sr_imgs, hr_imgs, sr_output)
g_loss.backward()
g_optimizer.step()
d_optimizer.zero_grad()
real_output = discriminator(hr_imgs)
fake_output = discriminator(sr_imgs.detach())
d_loss_real = F.binary_cross_entropy(real_output, torch.ones_like(real_output))
d_loss_fake = F.binary_cross_entropy(fake_output, torch.zeros_like(fake_output))
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
d_optimizer.step()
if (i + 1) % print_every == 0:
print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], "
f"G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}, "
f"Content Loss: {content_loss.item():.4f}, Adversarial Loss: {adversarial_loss.item():.4f}")
torch.save(generator.state_dict(), f"generator_epoch_{epoch + 1}.pth")
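After training, a saved generator can be used for single-image inference. A minimal sketch, reusing the objects defined above; the checkpoint and image file names are placeholders:

# Inference: super-resolve one LR image with a trained generator
generator.load_state_dict(torch.load("generator_epoch_100.pth", map_location=device))
generator.eval()
lr = transform(Image.open("test_lr.png").convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    sr = generator(lr)
sr = (sr.squeeze(0).cpu().clamp(-1, 1) + 1) / 2  # undo the [-1, 1] normalization
transforms.ToPILImage()(sr).save("test_sr.png")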
Code Notes
There are a few points to note in this implementation:
1. The full network is very large, so our residual blocks are deliberately small and shallow; if you enlarge the residual blocks, the input size of the fc1 layer must be changed accordingly (see the sketch after this list).
2. In the current design the low-resolution and high-resolution images differ in size: the latter is 4x the former per side (16x16 vs. 64x64).
3. The dataset must contain equal numbers of high-resolution and low-resolution images, paired one-to-one; don't feed it mismatched data!
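For point 1, one way to avoid recomputing the 131072 value by hand is to infer it with a dummy forward pass. A minimal sketch, assuming the discriminator's convolutional layers are factored into a callable (features and conv_stack below are hypothetical names, not from the code above):

def infer_fc_in_features(features, hr_size=64):
    # Push a dummy HR-sized batch through the conv stack and count flattened features
    with torch.no_grad():
        out = features(torch.zeros(1, 3, hr_size, hr_size))
    return out.view(1, -1).size(1)

# e.g. in Discriminator.__init__:
# self.fc1 = nn.Linear(infer_fc_in_features(self.conv_stack), 1024)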