【超分辨率】基于DDIM+SwinUnet实现超分辨率

详细代码及训练得到的8倍超分辨率模型已放在GitHub

Github: SuperResolution-DDIM-SwinUnet

简介

  • 在DIV2K数据集(800张2K图像)上训练了一个8倍超分辨率模型,采用了和sr3一样的:将低分辨率图像和噪声拼接输入模型。不过没有采用sr3的直接输入噪声强度,而是继续沿用输入去燥步骤t的方法,并增加了DDPM的步数到1000(如果仅是100步的话,输出结果的噪点会比较多)。

  • 效果图放在了Github的result目录里,引入了DDIM采样(这也是使用t作为时间条件的好处),从结果看DDIM仅需采样40步效果就和DDPM采样1000步相当了。而DDIM采样1步或2步也能大体还原,不过质量不高。

不足:

1.可能是使用SwinUnet的关系,超分辨率后的图像总是能隐约看到“小框框”;而且图像大小必须能被256整除(这个其实好解决,resize即可)。
2.只做了一个8倍超分辨率的模型(倍数太大,从效果来看失真率很高),可以考虑做倍率较低的比如2倍和4倍,进行拼接从而实现8倍的效果,可能失真率会好一点。

代码:(run.py、scheduler.py、SwinUnet.py、load_data.py、training.py)
"run.py"
import numpy as np
import torch

from SwinUnet import SwinUnet
from scheduler import Scheduler
from PIL import Image

import argparse
import datetime
import os

def main(args):

    device = torch.device(args.device)

    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)

    sr_ratio = args.sr_ratio

    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    scheduler = Scheduler(model, args.denoise_steps)

    image_path = args.image_path
    img = Image.open(image_path)

    img_size = img.size

    assert img_size[0] >= 256 and img_size[1] >= 256, "图片的最小尺寸为256"

    img_size = (
        (img_size[0] // 256) * 256 * sr_ratio,
        (img_size[1] // 256) * 256 * sr_ratio
    )

    img = img.resize(img_size)

    img_arr = np.array(img)

    if img_arr.shape[-1] == 4: img_arr = img_arr[..., :3]

    img_arr = img_arr.transpose(2, 0, 1) / 255.

    img_arr = 2 * (img_arr - 0.5)

    img_arr = torch.from_numpy(img_arr).float().to(device)
    img_arr = img_arr.unsqueeze(0)

    if args.use_ddim:
        y = scheduler.ddim(img_arr, device, sub_sequence_step=args.ddim_sub_sequence_steps)[-1]
    else:
        y = scheduler.ddpm(img_arr, device)[-1]

    y = y.transpose(1, 2, 0)
    y = (y + 1.) / 2
    y *= 255.0

    new_img = Image.fromarray(y.astype(np.uint8))

    new_img.save(os.path.join(args.results_dir, str(datetime.datetime.now()) + ".png"))


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--image_path", type=str)
    parser.add_argument("--sr_ratio", type=int, default=8)
    parser.add_argument("--results_dir", type=str, default="./results")
    parser.add_argument("--denoise_steps", type=int, default=1000)
    parser.add_argument("--model_path", type=str, default="SwinUNet-SR8.pth")
    parser.add_argument("--use_ddim", type=int, default=1)
    parser.add_argument("--ddim_sub_sequence_steps", type=int, default=25)
    args = parser.parse_args()
    main(args)
"scheduler.py"
import numpy as np

import torch

import torch.nn.functional as F

from tqdm import tqdm

def extract_into_tensor(arr, timesteps, broadcast_shape):

    res = torch.from_numpy(arr).to(torch.float32).to(device=timesteps.device)[timesteps]
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)

class Scheduler:

    def __init__(self, denoise_model, denoise_steps, beta_start=1e-4, beta_end=0.005):

        self.model = denoise_model

        betas = np.array(
            np.linspace(beta_start, beta_end, denoise_steps),
            dtype=np.float64
        )

        self.denoise_steps = denoise_steps

        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        alphas = 1.0 - betas

        self.sqrt_alphas = np.sqrt(alphas)
        self.one_minus_alphas = 1.0 - alphas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)

        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

    def q_sample(self, y0, t, noise):

        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, y0.shape) * y0
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y0.shape) * noise
        )

    def training_losses(self, x, y, t):

        noise = torch.randn_like(y)
        y_t = self.q_sample(y, t, noise)

        predict_noise = self.model(torch.cat([x, y_t], dim=1), t)

        return F.mse_loss(predict_noise, noise)


    @torch.no_grad()
    def ddpm(self, x, device):

        y = torch.randn(*x.shape, device=device)

        for t in tqdm(reversed(range(0, self.denoise_steps)), total=self.denoise_steps):

            t = torch.tensor([t], device=device).repeat(x.shape[0])
            t_mask = (t != 0).float().view(-1, *([1] * (len(y.shape) - 1)))

            eps = self.model(torch.cat([x, y], dim=1), t)

            y = y - (
                    extract_into_tensor(self
### 基于 Transformer 的图像超分辨率方法概述 近年来,基于 Transformer 的架构逐渐被应用于计算机视觉领域中的各种任务,其中包括图像超分辨率 (Super-Resolution, SR)[^1]。传统的超分辨率模型主要依赖卷积神经网络(CNNs),例如 SRCNN、VDSR 和 ESRGAN 等。然而,随着 Vision Transformers (ViTs) 的兴起,研究人员发现 Transformer 架构能够在捕捉全局上下文信息方面表现出更强的能力。 #### 方法原理 基于 Transformer 的超分辨率技术的核心在于通过自注意力机制捕获长距离的空间关系。具体来说,输入的低分辨率图像会被划分为多个固定大小的小块(patches)。这些 patches 被展平成一维向量,并嵌入到更高维度空间中以便后续处理[^2]。接着,Transformer Encoder 层会利用缩放点积注意力机制来学习 patch 序列间的相互作用,从而增强对复杂纹理的理解能力。 对于最终重建高分辨率图像的过程,则采用解码器结构或者直接叠加线性映射操作完成从特征表示回到像素域的任务转换。此外,在一些先进的设计方案里还会加入残差连接以及多尺度特征融合模块以进一步优化性能表现[^3]。 以下是实现该类算法的一个简化版本代码示例: ```python import torch from torch import nn class PatchEmbedding(nn.Module): def __init__(self, img_size=64, patch_size=8, embed_dim=768): super().__init__() self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x).flatten(2).transpose(1, 2) return x class TransformerBlock(nn.Module): """Basic transformer block.""" ... def build_model(): model = nn.Sequential( PatchEmbedding(), *[TransformerBlock() for _ in range(12)], # Example number of layers. nn.Linear(embed_dim, num_patches * channels_per_patch), Rearrange('b (h w c) -> b c h w', ...) ) return model ``` 以上仅为概念验证性质的基础框架示意,实际应用需依据特定需求调整参数设置与组件细节[^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值