1、前言
苦逼本科生搞毕设,需要对光学遥感卫星做超分辨率算法,腾讯元宝给我推荐了这个模型。
论文链接:OpenSR在IEEE上的论文
GitHub源代码:GitHub - ESAOpenSR/opensr-model
值得一提的是,该模型针对的是10m分辨率的,而我其实需要的是3m分辨率的。奈何网上的模型太杂,非常难找到针对性的模型,且我没有合适的数据集用来训练,只能优先寻找包含预训练模型的开源工程,所以可选择的就有限了。
2、配置环境
根据requirements.txt文件配置环境,由于之前没有经验,所以torch环境搞了半天。
具体操作流程,参考我的这篇博客:(还在审核中)
总之,搞了半天,终于按照requirements的要求,配置了适合的torch版本了。
3、opensr_model出现问题
报错如下:该模型的ssl文件出现了问题
询问腾讯元宝后,得到可能是OpenSSL DLL缺失,需要安装OpenSSL并修复DLL路径
但是,千万别信!千万别信!千万别信!我被坑惨了!
解决方法其实很简单,参考这篇博客即可:ImportError: DLL load failed while importing _ssl: 找不到指定的模块。-优快云博客
4、最终结果:寄!
毫无心气的寄了,没辙,搞了两天,也只能换一个model了。
以下是豆包给出的原因,我估计主要还是我的图像分辨率为3.2m,而模型里给的是10m。
不过按理说,哪怕分辨率有差距,也不至于是乱码。个人感觉恐怕大概率还是输入的时候哪里出了问题吧。
我的代码也是在豆包帮助下弄的,既然结果天差地别,就放出来仅供参考吧。至少他还能稳定运行:
import os
import torch
from PIL import Image
import numpy as np
import opensr_model
from opensr_model import SRLatentDiffusion
import matplotlib.pyplot as plt
import tifffile
# 配置参数(根据实际需求调整)
MODEL_TYPE = "10m" # 注意:此处仍需与模型实际输入匹配
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
CHECKPOINT_PATH = r"F:\python_project\LDSR-S2\opensr_10m_v4_v4.ckpt"
INPUT_IMAGE = r"F:\python_project\LDSR-S2\opensr-model-main\my_input\image1.tif" # 必须为四通道图像
OUTPUT_DIR = "my_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 加载模型(关键:确保model.bands与输入通道数一致)
model = SRLatentDiffusion(bands=MODEL_TYPE, device=DEVICE)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE), strict=False)
model.eval()
# 定义四通道图像读取函数
def read_image(img_path):
img = tifffile.imread(img_path) # 读取四通道的 TIF 文件
img = torch.from_numpy(img).float().permute(2, 0, 1) # 转换为 PyTorch 张量
return img.unsqueeze(0).to(DEVICE)
# 主流程
try:
# 加载图像(四通道)
input_img = read_image(INPUT_IMAGE)
print("Input Image Shape:", input_img.shape) # (1, 4, H, W)
# 验证输入维度
assert input_img.shape[1] == 4, "输入图像通道数错误!"
# 提取 RGB 三个通道
rgb = input_img[0, :3, :, :].cpu().numpy()
rgb = np.clip(rgb, 0, 255).astype(np.uint8)
# 显示 RGB 图像
plt.figure(figsize=(10, 5))
plt.imshow(rgb.transpose(1, 2, 0))
plt.title("RGB Image from Input")
plt.axis("off")
plt.show()
print("已显示输入图像的 RGB 通道。")
# 如果你还想继续进行超分推理,可以保留以下代码
# LDSR-S2超分推理
with torch.no_grad():
sr = model(input_img)
sr = sr.squeeze().cpu().numpy()
assert sr.shape[0] == 4, "超分结果通道数错误!"
sr = np.clip(sr, 0, 255).astype(np.uint8)
# 保存 BGRN 四通道的 TIF 文件
tifffile.imwrite(os.path.join(OUTPUT_DIR, "superres_bgrn.tif"), sr.transpose(1, 2, 0))
# 提取 RGB 三个通道并保存为 RGB 图像
rgb = sr[:3, :, :]
Image.fromarray(rgb.transpose(1, 2, 0)).convert("RGB").save(
os.path.join(OUTPUT_DIR, "superres_rgb.png")
)
# 显示 RGB 超分结果
plt.figure(figsize=(10, 5))
plt.imshow(rgb.transpose(1, 2, 0))
plt.title("Super-Resolution RGB Result")
plt.axis("off")
plt.show()
print("超分完成!结果保存在 my_output 文件夹中。")
except FileNotFoundError as e:
print(f"文件未找到错误: {str(e)}")
except AssertionError as e:
print(f"维度验证失败: {str(e)}")
except Exception as e:
print(f"未知错误: {str(e)}")