【论文复现】SRGAN

1. 项目结构

如何生成文件夹的文件目录呢?

按住shift键,右击你要生成目录的文件夹,选择“在此处打开Powershell窗口”

在命令窗口里输入命令“tree”,按回车。就会显示出目录结构。

├─.idea
│  └─inspectionProfiles
├─benchmark_results
├─data
│  ├─test
│  │  ├─Manga109
│  │  ├─Set14
│  │  ├─Set5
│  │  └─Urban100
│  ├─train_DIV2K_HR
│  └─valid_DIV2K_HR
├─epochs
├─pytorch_ssim
│  └─__pycache__
├─statistics
├─training_results
│  └─SRF_4
└─__pycache__

为了更好地记录这个代码文件夹的结构,我再把.py文件添上去

├─.idea
│  └─inspectionProfiles
├─benchmark_results
├─data
│  ├─test
│  │  ├─Manga109
│  │  ├─Set14
│  │  ├─Set5
│  │  └─Urban100
│  ├─train_DIV2K_HR
│  └─valid_DIV2K_HR
├─epochs
├─pytorch_ssim
│  └─__pycache__
├─statistics
├─training_results
│  └─SRF_4
├─data_utils.py
├─loss.py
├─model.py
├─README.md
├─test_benchmark.py
├─test_image.py
└─train.py

 才拿到代码包的时候,每一个空文件夹下都有一个“.gitkeep文件”。

那么什么是“.gitkeep文件”呢?

因为Git 是一个文件追踪系统,所以Git 不会追踪一个空目录。当我们需要保留空目录的时候,“.gitkeep文件”可以使 Git 保留一个空文件夹。


2. 实验细节

 算法名称 SRGAN
图像域 RGB
下采样方法 双三次核函数下采样4⨉
目标函数 内容损失+对抗损失
生成器 SRResNet
判别器 VGG:判别HR与SR
训练集 DIV2K,800张
验证集 DIV2K,100张
测试集 Set5、Set14、BSD100、Urban100、Manga109

参数配置(在train.py中)

parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')

GPU 

为了方便,然后训练集本来也不大,就在本地的NVIDIA GeForce RTX 3050上跑的

持续运行

因为经常把电脑背来背去的,会放进包里,所以要求电脑合上的时候程序也能继续运行。具体实现方法是:

1、点击开始图标,点击控制面板。

2、查看方式选择为“类别”,找到“硬件和声音”功能并点击。

3、在硬件和声音页面,找到更改电源按钮的功能选项并点击。

4、将“关闭盖子时”后方都设置为“不采取任何操作”,最后保存修改即可。

3. 项目解析 

benchmark_results

训练完成后,训练结果会保存到benchmark_results 文件夹中

data

存放训练集、验证集、测试集的地方。

epochs

用于存放每个epoch训练得到的生成器和判别器的模型参数。

pytorch_ssim

计算结构相似性指数SSIM

statistics

存放记录每个epoch训练结果的表格,每跑10个epochs记录一次

training_results

存放验证集结果

里面有一个名为“SRF_4”的文件夹,意思是4⨉的双三次核函数下采样、放大因子为4。

“SRF_4”的文件夹存放着每一个epoch在验证集上的可视化结果,于展示图像超分辨率模型在训练过程中的性能表现。

每组图片包含三列:原始低分辨率图像(val_hr_restore)、对应的高分辨率图像(val_hr)以及模型生成的超分辨率图像(sr)。

data_utils.py

数据集加载

train.py有一行导包代码

from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform

所以data_utils.py有一些有关数据集加载的函数以供train.py使用。

都是从库里导的包

from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize

对图像的一些处理 

# 判断文件名是否为常见图像文件格式(不区分大小写)
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

# 根据裁剪尺寸和放大因子计算有效的裁剪尺寸,确保能被放大因子整除
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

# 定义高分辨率训练图像的变换操作:随机裁剪后转换为张量
def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),# 随机裁剪图像,裁剪尺寸为传入的crop_size
        ToTensor(),# 将裁剪后的图像转换为张量格式
    ])

# 定义低分辨率训练图像的变换操作:先转换为PIL图像,缩放后再转换为张量
def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),# 将高分辨率图像张量转换为PIL图像对象
        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),# 将图像缩小为原始尺寸除以放大因子,使用双三次插值
        ToTensor()# 将缩放后的低分辨率图像转换为张量格式
    ])

# 定义用于显示图像的变换操作:调整大小、中心裁剪后转换为张量
def display_transform():
    return Compose([
        ToPILImage(),  # 将图像转换为PIL图像对象
        Resize(400),  # 将图像大小调整为400(可能是为了统一显示尺寸)
        CenterCrop(400),  # 进行中心裁剪,确保关键部分完整
        ToTensor()  # 将处理后的图像转换为张量格式
    ])

从文件夹中加载和预处理训练图像数据

class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)

从文件夹中加载和预处理验证图像数据 

class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

 从文件夹中加载和预处理测试图像数据 

class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值