The full code and the trained 8x super-resolution model are available on GitHub.
GitHub: SuperResolution-DDIM-SwinUnet
Introduction
- An 8x super-resolution model was trained on the DIV2K dataset (800 2K images), using the same conditioning scheme as SR3: the low-resolution image is concatenated with the noise as the model input. Unlike SR3, which feeds in the noise level directly, this model keeps the denoising step t as the time condition, and the number of DDPM steps is increased to 1000 (with only 100 steps the output is noticeably noisy). The objective is written out as a formula after this list.
- Result images are in the result directory on GitHub. DDIM sampling is also supported (another benefit of using t as the time condition); from the results, DDIM needs only about 40 sampling steps to match 1000-step DDPM sampling, and even 1 or 2 DDIM steps roughly reconstruct the image, though at low quality.
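For reference, this is the SR3-style objective implemented by q_sample and training_losses in scheduler.py below: the high-resolution target y_0 is noised by the standard DDPM forward process, and the network predicts that noise from the concatenation of the (upsampled) low-resolution image x and the noisy target:

$$
y_t = \sqrt{\bar\alpha_t}\, y_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, I),
\qquad
\mathcal{L} = \mathbb{E}\,\big\|\epsilon_\theta(\mathrm{concat}(x, y_t),\, t) - \epsilon\big\|^2
$$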
Limitations:
1. Possibly due to the use of SwinUnet, faint block artifacts ("little boxes") are visible in the super-resolved images; also, the input dimensions must be divisible by 256 (this is easy to work around with a resize).
2. Only a single 8x model was trained (at such a large scale factor the results show considerable distortion); lower factors such as 2x and 4x could be trained and chained to reach 8x, which might reduce the distortion (see the sketch after this list).
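A minimal sketch of that cascade, assuming two hypothetical checkpoints sr2x.pth and sr4x.pth trained the same way as the released 8x model (neither exists in the repo; the architecture and the Scheduler/ddim usage mirror run.py and scheduler.py below):

import torch
import torch.nn.functional as F
from SwinUnet import SwinUnet
from scheduler import Scheduler

def load_sr_model(path, device):
    # Same architecture as in run.py; only the checkpoint differs (hypothetical files)
    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model

@torch.no_grad()
def cascade_sr(lr_batch, device, steps=1000):
    # A 2x stage followed by a 4x stage gives 8x overall. As in run.py, each
    # stage conditions on its input upsampled to the stage's output resolution.
    out = lr_batch  # (1, 3, H, W) in [-1, 1]
    for ckpt, ratio in [("sr2x.pth", 2), ("sr4x.pth", 4)]:  # hypothetical checkpoints
        h, w = out.shape[-2:]
        cond = F.interpolate(out, size=(h * ratio, w * ratio),
                             mode="bicubic", align_corners=False)
        scheduler = Scheduler(load_sr_model(ckpt, device), steps)
        out_np = scheduler.ddim(cond, device, sub_sequence_step=25)[-1]  # CHW numpy
        out = torch.from_numpy(out_np).unsqueeze(0).float().to(device)
    return out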
Code: run.py, scheduler.py, SwinUnet.py, load_data.py, training.py
"run.py"
import numpy as np
import torch
from SwinUnet import SwinUnet
from scheduler import Scheduler
from PIL import Image
import argparse
import datetime
import os
def main(args):
    device = torch.device(args.device)
    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    sr_ratio = args.sr_ratio
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()
    scheduler = Scheduler(model, args.denoise_steps)

    # Round the side lengths down to multiples of 256, then scale by sr_ratio;
    # the model requires input sides divisible by 256
    img = Image.open(args.image_path)
    img_size = img.size
    assert img_size[0] >= 256 and img_size[1] >= 256, "minimum image size is 256"
    img_size = (
        (img_size[0] // 256) * 256 * sr_ratio,
        (img_size[1] // 256) * 256 * sr_ratio
    )
    img = img.resize(img_size)

    # HWC uint8 -> CHW float in [-1, 1]; drop the alpha channel if present
    img_arr = np.array(img)
    if img_arr.shape[-1] == 4:
        img_arr = img_arr[..., :3]
    img_arr = img_arr.transpose(2, 0, 1) / 255.
    img_arr = 2 * (img_arr - 0.5)
    img_arr = torch.from_numpy(img_arr).float().to(device)
    img_arr = img_arr.unsqueeze(0)

    # Sample with DDIM (fast) or DDPM (all denoise_steps); [-1] is the final image
    if args.use_ddim:
        y = scheduler.ddim(img_arr, device, sub_sequence_step=args.ddim_sub_sequence_steps)[-1]
    else:
        y = scheduler.ddpm(img_arr, device)[-1]

    # CHW in [-1, 1] -> HWC uint8
    y = y.transpose(1, 2, 0)
    y = (y + 1.) / 2
    y *= 255.0
    new_img = Image.fromarray(y.astype(np.uint8))
    os.makedirs(args.results_dir, exist_ok=True)
    new_img.save(os.path.join(
        args.results_dir,
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + ".png"
    ))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--sr_ratio", type=int, default=8)
    parser.add_argument("--results_dir", type=str, default="./results")
    parser.add_argument("--denoise_steps", type=int, default=1000)
    parser.add_argument("--model_path", type=str, default="SwinUNet-SR8.pth")
    parser.add_argument("--use_ddim", type=int, default=1)
    parser.add_argument("--ddim_sub_sequence_steps", type=int, default=25)
    args = parser.parse_args()
    main(args)
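An example invocation (the flags are those defined above; test.png is a hypothetical input image):

python run.py --device cuda --image_path test.png --sr_ratio 8 --model_path SwinUNet-SR8.pth --use_ddim 1 --ddim_sub_sequence_steps 25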
"scheduler.py"
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
def extract_into_tensor(arr, timesteps, broadcast_shape):
    # Gather per-timestep coefficients and broadcast them over the batch shape
    res = torch.from_numpy(arr).to(torch.float32).to(device=timesteps.device)[timesteps]
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)

class Scheduler:
    def __init__(self, denoise_model, denoise_steps, beta_start=1e-4, beta_end=0.005):
        self.model = denoise_model
        # Linear beta schedule
        betas = np.array(
            np.linspace(beta_start, beta_end, denoise_steps),
            dtype=np.float64
        )
        self.denoise_steps = denoise_steps
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()
        alphas = 1.0 - betas
        self.sqrt_alphas = np.sqrt(alphas)
        self.one_minus_alphas = 1.0 - alphas  # == betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

    def q_sample(self, y0, t, noise):
        # Forward process: y_t = sqrt(alpha_bar_t) * y0 + sqrt(1 - alpha_bar_t) * noise
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, y0.shape) * y0
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y0.shape) * noise
        )

    def training_losses(self, x, y, t):
        # Noise the HR target y, condition on the upsampled LR image x by
        # channel-wise concatenation, and regress the added noise
        noise = torch.randn_like(y)
        y_t = self.q_sample(y, t, noise)
        predict_noise = self.model(torch.cat([x, y_t], dim=1), t)
        return F.mse_loss(predict_noise, noise)
    @torch.no_grad()
    def ddpm(self, x, device):
        y = torch.randn(*x.shape, device=device)
        for t in tqdm(reversed(range(0, self.denoise_steps)), total=self.denoise_steps):
            t = torch.tensor([t], device=device).repeat(x.shape[0])
            t_mask = (t != 0).float().view(-1, *([1] * (len(y.shape) - 1)))  # no noise at t == 0
            eps = self.model(torch.cat([x, y], dim=1), t)
            # Standard DDPM ancestral step with sigma_t = sqrt(beta_t):
            # y_{t-1} = (y_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t) + sigma_t * z
            y = y - (extract_into_tensor(self.one_minus_alphas, t, y.shape)
                     / extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y.shape)) * eps
            y = y / extract_into_tensor(self.sqrt_alphas, t, y.shape)
            y = y + t_mask * torch.sqrt(
                extract_into_tensor(self.one_minus_alphas, t, y.shape)) * torch.randn_like(y)
        # Return a list so callers can index [-1] for the final image (CHW numpy array)
        return [y.squeeze(0).cpu().numpy()]
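    # --- Hedged sketch: the original ddim implementation is not shown in this
    # excerpt. Below is a standard deterministic DDIM sampler (eta = 0) written
    # against the attributes above and run.py's calling convention,
    # scheduler.ddim(x, device, sub_sequence_step=...)[-1]; treat it as an
    # illustration rather than the exact released code. ---
    @torch.no_grad()
    def ddim(self, x, device, sub_sequence_step=25):
        # Keep every sub_sequence_step-th timestep: with denoise_steps=1000 and
        # sub_sequence_step=25 this gives the 40 DDIM steps mentioned above
        timesteps = list(range(0, self.denoise_steps, sub_sequence_step))
        y = torch.randn(*x.shape, device=device)
        samples = []
        for i in tqdm(reversed(range(len(timesteps))), total=len(timesteps)):
            t = torch.tensor([timesteps[i]], device=device).repeat(x.shape[0])
            eps = self.model(torch.cat([x, y], dim=1), t)
            alpha_bar = extract_into_tensor(self.alphas_cumprod, t, y.shape)
            # Predict the clean image y0, then jump to the previous kept timestep
            y0 = (y - torch.sqrt(1.0 - alpha_bar) * eps) / torch.sqrt(alpha_bar)
            if i > 0:
                t_prev = torch.tensor([timesteps[i - 1]], device=device).repeat(x.shape[0])
                alpha_bar_prev = extract_into_tensor(self.alphas_cumprod, t_prev, y.shape)
                y = torch.sqrt(alpha_bar_prev) * y0 + torch.sqrt(1.0 - alpha_bar_prev) * eps
            else:
                y = y0
            samples.append(y.squeeze(0).cpu().numpy())
        return samples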
