PyTorch Implementation of ISRGAN: High-Resolution Image Generation, Trained on Your Own Low-Resolution Images

Introduction

ISRGAN is a modified SRGAN with a deeper residual structure. It performs super-resolution (SR) on small input images, enhancing and restoring fine detail when the image is rescaled to a larger size.

Paper: Improved Generative Adversarial Network for Generating High-Resolution Images from Low-Resolution Images

Venue: 2023 13th International Conference on Cloud Computing, Data Science & Engineering (Confluence)

Abstract: The process of generating a high-resolution (HR) image from a low-resolution (LR) image is called super-resolution (SR). In this paper we attempt to perform SR using deep learning techniques. For better performance in medical imaging, forensics, pattern recognition, satellite imaging, surveillance, and similar fields, it must be possible to zoom into a specific region of interest in an image, which requires that region to be available at high resolution. We propose ISRGAN (Improved Super-Resolution Generative Adversarial Network), an improved version of SRGAN (Super-Resolution Generative Adversarial Network) for image SR. Images from the DIV2K dataset were used to train the network. The dataset consists of 800 distinct images, resized to 32x32 and 128x128 pixels to form the LR and HR images respectively. A generative adversarial network is based on the idea of adversarial training and consists of two parts, a discriminator network and a generator network. The generator produces an HR image from the LR image, while the discriminator classifies the generated image as fake or real. The proposed model produces decent results, and our final super-resolved outputs show that the proposed ISRGAN model generates images with enhanced features.

Model Architecture

The most striking feature of the model is its very deep stack of residual blocks, which gives it a huge parameter count; low-end machines, and even some servers, cannot train the full network. The proposed work is a modified and enhanced version of SRGAN, intended to improve the model's effectiveness. Like SRGAN, it uses a deep residual network (ResNet) with skip connections, together with feature maps from a pretrained VGG19 network.

Generator Architecture

ISRGAN's generator network uses 20 identical residual blocks, compared with the 16 blocks used in SRGAN. In addition, each residual block gains 2 extra convolutional layers, followed by 2 batch-normalization layers and 2 parametric ReLU (PReLU) layers.
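
For reference, here is a minimal sketch of one such generator residual block following the description above (an SRGAN-style block plus two extra conv/BN/PReLU stages). This is an illustration of the description, not the paper's reference code; the channel width of 64 is an assumption carried over from the training listing later in this post.

import torch.nn as nn

class ISRGANResBlock(nn.Module):
    # One generator residual block: three conv/BN/PReLU stages plus a
    # final conv/BN, wrapped in an identity skip connection.
    def __init__(self, channels=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.body(x)  # identity skip connection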

Discriminator Architecture

SRGAN's discriminator network has no residual blocks; ISRGAN adds 5 residual blocks to its discriminator. The discriminator uses leaky ReLU activations in its convolutional blocks and PReLU inside the residual blocks. It ends with dense layers and a sigmoid activation that classifies an image as real or generated.

Training on Your Own Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import vgg19
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Residual block used inside the discriminator. It is kept shallow here;
# the commented-out layers show how to deepen it as in the paper.
class DResBlock(nn.Module):
    def __init__(self):
        super(DResBlock, self).__init__()
        self.resblock = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64),
            # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64)
        )

    def forward(self, x):
        x = self.resblock(x)
        return x

# Discriminator: classifies an image as real (HR) or generated (SR).
# Only one of the paper's five residual blocks and one downsampling
# stage are enabled here to keep the model small.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.lrelu = nn.LeakyReLU()
        self.prelu = nn.PReLU()
        self.resblok1 = DResBlock()
        # self.resblok2 = DResBlock()
        # self.resblok3 = DResBlock()
        # self.resblok4 = DResBlock()
        # self.resblok5 = DResBlock()

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn64 = nn.BatchNorm2d(64)

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )
        # self.block2 = nn.Sequential(
        #     nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
        #     nn.BatchNorm2d(128),
        #     nn.LeakyReLU()
        # )
        # self.block3 = nn.Sequential(
        #     nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
        #     nn.BatchNorm2d(256),
        #     nn.LeakyReLU()
        # )
        # self.block4 = nn.Sequential(
        #     nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
        #     nn.BatchNorm2d(256),
        #     nn.LeakyReLU()
        # )
        # self.block5 = nn.Sequential(
        #     nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
        #     nn.BatchNorm2d(512),
        #     nn.LeakyReLU()
        # )
        # self.block6 = nn.Sequential(
        #     nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1),
        #     nn.BatchNorm2d(512),
        #     nn.LeakyReLU()
        # )

        # 128 channels x 32 x 32 spatial positions for 64x64 HR inputs;
        # recompute this if you change the input size or re-enable blocks
        self.fc1 = nn.Linear(131072, 1024)
        self.lrelu_fc = nn.LeakyReLU()
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.lrelu(self.bn64(self.conv1_2(self.lrelu(self.conv1(x)))))
        out = self.resblok1(x)
        out1 = out + x
        # x = self.resblok2(out1)
        # out2 = x + out1
        # x = self.resblok2(out2)
        # out3 = x + out2
        # x = self.resblok2(out3)
        # out4 = x + out3
        # x = self.resblok2(out4)
        # out5 = x + out4
        # out = self.bn64(self.conv2(out5))
        # out = out + out5

        out = self.block1(out1)
        # out = self.block2(out)
        # out = self.block3(out)
        # out = self.block4(out)
        # out = self.block5(out)
        # out = self.block6(out)

        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.lrelu_fc(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

# Residual block used inside the generator (a reduced version of the
# paper's block; the commented-out layers restore the full depth)
class GResBlock(nn.Module):
    def __init__(self):
        super(GResBlock, self).__init__()
        self.resblock = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64),
            # nn.PReLU(),
            # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64),
            # nn.PReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

    def forward(self, x):
        x = self.resblock(x)
        return x

# Generator: upscales a 16x16 LR image to 64x64 (4x) via two
# PixelShuffle(2) stages
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4)  
        self.prelu = nn.PReLU()
        self.block = GResBlock()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=9, stride=1, padding=4)  # 64 channels / 2^2 after the second PixelShuffle
        self.bn = nn.BatchNorm2d(64)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu(x)

        # a single residual block is reused here; for the paper's 20 blocks,
        # instantiate 20 separate GResBlock modules (looping over one module
        # would share weights across iterations)
        for i in range(1):
            x = self.block(x) + x

        x = self.bn(self.conv2(x))
        x = self.prelu(self.pixel_shuffle(self.conv3(x)))
        x = self.prelu(self.pixel_shuffle(self.conv4(x)))
        x = self.conv5(x)
        return x

# VGG19-based perceptual (content) loss
class ContentLoss(nn.Module):
    def __init__(self):
        super(ContentLoss, self).__init__()
        # `pretrained=True` is deprecated in recent torchvision releases;
        # there, use vgg19(weights="IMAGENET1K_V1").features instead
        vgg = vgg19(pretrained=True).features
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:35])
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.feature_extractor = self.feature_extractor.to(device)

    def forward(self, sr, hr):
        if sr.dim() != 4 or hr.dim() != 4:
            raise ValueError(f"Expected 4D input, but got sr: {sr.dim()}D, hr: {hr.dim()}D")

        sr_features = self.feature_extractor(sr)
        hr_features = self.feature_extractor(hr)
        return F.mse_loss(sr_features, hr_features)

# Adversarial loss for the generator: pushes the discriminator's
# output on generated images toward "real"
class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()

    def forward(self, sr_output):
        return F.binary_cross_entropy(sr_output, torch.ones_like(sr_output))

# Total generator loss: weighted sum of content and adversarial terms
class TotalLoss(nn.Module):
    def __init__(self, content_weight=1e-3, adversarial_weight=1e-3):
        super(TotalLoss, self).__init__()
        self.content_loss = ContentLoss()
        self.adversarial_loss = AdversarialLoss()
        self.content_weight = content_weight
        self.adversarial_weight = adversarial_weight

    def forward(self, sr, hr, sr_output):
        # print(f"sr shape: {sr.shape}, hr shape: {hr.shape}")  # uncomment to debug shapes

        content_loss = self.content_loss(sr, hr)
        adversarial_loss = self.adversarial_loss(sr_output)
        total_loss = self.content_weight * content_loss + self.adversarial_weight * adversarial_loss
        return total_loss, content_loss, adversarial_loss


# Paired LR/HR dataset: files are matched by sorted order, so both
# directories must contain the same number of corresponding images
class SRDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, lr_transform=None, hr_transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])

        lr_img = Image.open(lr_path).convert('RGB')
        hr_img = Image.open(hr_path).convert('RGB')

        if self.lr_transform:
            lr_img = self.lr_transform(lr_img)
        if self.hr_transform:
            hr_img = self.hr_transform(hr_img)

        return lr_img, hr_img


# LR pipeline: force an exact 16x16 output (an (h, w) tuple is needed so
# non-square images still batch correctly and match the fixed fc1 size)
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# HR pipeline: 64x64, i.e. 4x the LR size
transform2 = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

lr_dir = r"低分辨率图像"
hr_dir = r"高分辨率图像"


train_dataset = SRDataset(lr_dir, hr_dir, lr_transform=transform, hr_transform=transform2)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)


generator = Generator().to(device)
discriminator = Discriminator().to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))
total_loss = TotalLoss(content_weight=1e-3, adversarial_weight=1e-3)


num_epochs = 100
print_every = 10

for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

      
        # ---- generator update ----
        g_optimizer.zero_grad()
        sr_imgs = generator(lr_imgs)
        sr_output = discriminator(sr_imgs)
        g_loss, content_loss, adversarial_loss = total_loss(sr_imgs, hr_imgs, sr_output)
        g_loss.backward()
        g_optimizer.step()

        
        # ---- discriminator update ----
        d_optimizer.zero_grad()
        real_output = discriminator(hr_imgs)
        fake_output = discriminator(sr_imgs.detach())
        d_loss_real = F.binary_cross_entropy(real_output, torch.ones_like(real_output))
        d_loss_fake = F.binary_cross_entropy(fake_output, torch.zeros_like(fake_output))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

      
        if (i + 1) % print_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], "
                  f"G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}, "
                  f"Content Loss: {content_loss.item():.4f}, Adversarial Loss: {adversarial_loss.item():.4f}")

# save the final generator weights (`epoch` still holds its last value here)
torch.save(generator.state_dict(), f"generator_epoch_{epoch + 1}.pth")
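
After training, the generator can be used on its own for inference. Below is a minimal usage sketch; the checkpoint file name and the test image path are assumptions, and it reuses the Generator class and normalization from the listing above.

import torch
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
generator.load_state_dict(torch.load("generator_epoch_100.pth", map_location=device))
generator.eval()

preprocess = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

lr_img = Image.open("test_lr.png").convert("RGB")  # hypothetical test image
with torch.no_grad():
    sr = generator(preprocess(lr_img).unsqueeze(0).to(device))

# undo the [-1, 1] normalization and save the 64x64 result
sr = (sr.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
transforms.ToPILImage()(sr).save("test_sr.png")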

Code Notes

There are a few things to watch out for in this implementation:

1. The full network is very large, so the residual blocks here are deliberately small and shallow. If you enlarge the residual blocks (or re-enable the commented-out discriminator stages), the input size of the fc1 layer must change accordingly; see the sketch after this list.

2. As currently configured, the low-resolution and high-resolution images differ in size: the HR side length is 4x the LR side length (64x64 vs. 16x16).

3. The LR and HR directories must contain the same number of images, paired one-to-one; do not put mismatched files in them.
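
Regarding note 1: rather than hand-computing the flattened feature size, you can infer it with a dummy forward pass. A minimal sketch, assuming 3-channel 64x64 HR inputs; the helper name and the wrapped-layer example are hypothetical:

import torch
import torch.nn as nn

def infer_flatten_size(conv_part: nn.Module, input_shape=(3, 64, 64)) -> int:
    # run a dummy tensor through the convolutional part of the
    # discriminator and return the flattened feature size for fc1
    with torch.no_grad():
        dummy = torch.zeros(1, *input_shape)
        out = conv_part(dummy)
    return out.view(1, -1).size(1)

# hypothetical usage inside Discriminator.__init__: wrap the convolutional
# layers (ignoring the residual skip, which preserves the shape anyway)
# and size the linear layer from the result:
# conv_part = nn.Sequential(self.conv1, self.lrelu, self.conv1_2,
#                           self.bn64, self.lrelu, self.resblok1, self.block1)
# self.fc1 = nn.Linear(infer_flatten_size(conv_part), 1024)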
