ESAOpenSR的GitHub工程复现

1、前言

苦逼本科生搞毕设,需要对光学遥感卫星做超分辨率算法,腾讯元宝给我推荐了这个模型。

论文链接:OpenSR在IEEE上的论文

GitHub源代码:GitHub - ESAOpenSR/opensr-model

值得一提的是,该模型针对的是10m分辨率的,而我其实需要的是3m分辨率的。奈何网上的模型太杂,非常难找到针对性的模型,且我没有合适的数据集用来训练,只能优先寻找包含预训练模型的开源工程,所以可选择的就有限了。

2、配置环境

根据requirements.txt文件配置环境,由于之前没有经验,所以torch环境搞了半天。

具体操作流程,参考我的这篇博客:(还在审核中)

总之,搞了半天,终于按照requirements的要求,配置了适合的torch版本了。

3、opensr_model出现问题

报错如下:该模型的ssl文件出现了问题

询问腾讯元宝后,得到可能是​OpenSSL DLL缺失,需要安装OpenSSL并修复DLL路径

但是,千万别信!千万别信!千万别信!我被坑惨了!

解决方法其实很简单,参考这篇博客即可:ImportError: DLL load failed while importing _ssl: 找不到指定的模块。-优快云博客

4、最终结果:寄!

毫无心气的寄了,没辙,搞了两天,也只能换一个model了。

以下是豆包给出的原因,我估计主要还是我的图像分辨率为3.2m,而模型里给的是10m。

不过按理说,哪怕分辨率有差距,也不至于是乱码。个人感觉恐怕大概率还是输入的时候哪里出了问题吧。

我的代码也是在豆包帮助下弄的,既然结果天差地别,就放出来仅供参考吧。至少他还能稳定运行:

import os
import torch
from PIL import Image
import numpy as np
import opensr_model
from opensr_model import SRLatentDiffusion
import matplotlib.pyplot as plt
import tifffile

# 配置参数(根据实际需求调整)
MODEL_TYPE = "10m"         # 注意:此处仍需与模型实际输入匹配
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

CHECKPOINT_PATH = r"F:\python_project\LDSR-S2\opensr_10m_v4_v4.ckpt"
INPUT_IMAGE = r"F:\python_project\LDSR-S2\opensr-model-main\my_input\image1.tif"  # 必须为四通道图像
OUTPUT_DIR = "my_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 加载模型(关键:确保model.bands与输入通道数一致)
model = SRLatentDiffusion(bands=MODEL_TYPE, device=DEVICE)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE), strict=False)
model.eval()

# 定义四通道图像读取函数
def read_image(img_path):
    img = tifffile.imread(img_path)  # 读取四通道的 TIF 文件
    img = torch.from_numpy(img).float().permute(2, 0, 1)  # 转换为 PyTorch 张量
    return img.unsqueeze(0).to(DEVICE)

# 主流程
try:
    # 加载图像(四通道)
    input_img = read_image(INPUT_IMAGE)
    print("Input Image Shape:", input_img.shape)  # (1, 4, H, W)

    # 验证输入维度
    assert input_img.shape[1] == 4, "输入图像通道数错误!"

    # 提取 RGB 三个通道
    rgb = input_img[0, :3, :, :].cpu().numpy()
    rgb = np.clip(rgb, 0, 255).astype(np.uint8)

    # 显示 RGB 图像
    plt.figure(figsize=(10, 5))
    plt.imshow(rgb.transpose(1, 2, 0))
    plt.title("RGB Image from Input")
    plt.axis("off")
    plt.show()

    print("已显示输入图像的 RGB 通道。")

    # 如果你还想继续进行超分推理,可以保留以下代码
    # LDSR-S2超分推理
    with torch.no_grad():
        sr = model(input_img)

    sr = sr.squeeze().cpu().numpy()
    assert sr.shape[0] == 4, "超分结果通道数错误!"
    sr = np.clip(sr, 0, 255).astype(np.uint8)

    # 保存 BGRN 四通道的 TIF 文件
    tifffile.imwrite(os.path.join(OUTPUT_DIR, "superres_bgrn.tif"), sr.transpose(1, 2, 0))

    # 提取 RGB 三个通道并保存为 RGB 图像
    rgb = sr[:3, :, :]
    Image.fromarray(rgb.transpose(1, 2, 0)).convert("RGB").save(
        os.path.join(OUTPUT_DIR, "superres_rgb.png")
    )

    # 显示 RGB 超分结果
    plt.figure(figsize=(10, 5))
    plt.imshow(rgb.transpose(1, 2, 0))
    plt.title("Super-Resolution RGB Result")
    plt.axis("off")
    plt.show()

    print("超分完成!结果保存在 my_output 文件夹中。")

except FileNotFoundError as e:
    print(f"文件未找到错误: {str(e)}")
except AssertionError as e:
    print(f"维度验证失败: {str(e)}")
except Exception as e:
    print(f"未知错误: {str(e)}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值