参考的 GitHub 代码链接：https://github.com/shx951104/remote-sensing-images-fusion/blob/b4b4147c7896468516bd84c544a98270cd26589b/starfm_torch.py
在此基础上稍作修改。
输入：
一组同一日期的高分辨率数据 s1 和低分辨率数据 l1；
t2 时刻的低分辨率数据 l2。
输出：t2 时刻的高分辨率数据 s2。
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 17 15:15:36 2020
@author: Administrator
"""
import numpy as np
import torch
import torch.nn as nn
import time
#import skimage.measure as sm
import skimage.metrics as sm
import cv2
from osgeo import gdal,gdalconst
import matplotlib.pyplot as plt
import skimage.io as io
from skimage.transform import resize
from utils import *
###weight calculate tools#####################################################
def weight_caculate(data):
    """Turn a difference tensor into a log-scaled, strictly positive weight.

    Computes ``log(|data| * 10000 + 1.00001)`` element-wise. The ``1.00001``
    offset keeps the logarithm above zero even where ``data`` is exactly 0,
    so the resulting factor never vanishes.

    Parameters
    ----------
    data : torch.Tensor
        Difference values (presumably reflectance differences such as the
        s1-l1 or l1-l2 inputs of ``caculate_weight``).

    Returns
    -------
    torch.Tensor
        Weighting factor with the same shape as ``data``.
    """
    scaled = torch.abs(data) * 10000
    return torch.log(scaled + 1.00001)
def caculate_weight(s1l1, l1l2):
    """Combine the two difference-based factors into a single weight.

    Each input is passed through ``weight_caculate`` and the two resulting
    factors are multiplied element-wise.

    Parameters
    ----------
    s1l1 : torch.Tensor
        Difference between the s1 and l1 data at the same date (the original
        comment labels this the "atmos difference").
    l1l2 : torch.Tensor
        Difference between l1 and l2 (the temporal difference).

    Returns
    -------
    torch.Tensor
        Element-wise product of the two factors.
    """
    atmos_factor = weight_caculate(s1l1)
    temporal_factor = weight_caculate(l1l2)
    return atmos_factor * temporal_factor
###spatial distance calculate tool############################################
def indexdistance(window):
    """Build the relative spatial-distance matrix for one moving window.

    Each entry is ``1 + d / h``:

    - ``d`` is the Euclidean distance (in pixels) from that cell to the
      window centre ``((window[0]-1)//2, (window[1]-1)//2)``.
    - ``h`` is the half-width ``(window[0]-1)//2``, clamped to at least 1.

    The centre cell is therefore 1, and values grow towards the window edges.

    Parameters
    ----------
    window : int or sequence of two ints
        Window size. A single int ``w`` is treated as ``(w, w)``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(window[1], window[0])``; this is
        ``np.meshgrid``'s default ``'xy'`` layout.
    """
    if isinstance(window, (int, np.integer)):
        window = (window, window)
    distx, disty = np.meshgrid(np.arange(window[0]), np.arange(window[1]))
    centerlocx, centerlocy = (window[0] - 1) // 2, (window[1] - 1) // 2
    # The half-width is 0 for windows of size 1 or 2. Dividing by it would
    # produce inf/nan distances, so clamp it to 1.
    # NOTE(review): only window[0] is used for normalisation, so non-square
    # windows are normalised along one axis only; confirm this is intended.
    halfwidth = max((window[0] - 1) // 2, 1)
    offset = np.sqrt((distx - centerlocx) ** 2 + (disty - centerlocy) ** 2)
    dist = 1 + offset / halfwidth
    return dist
###threshold select tool######################################################
def weight_bythreshold(weight, data, threshold):
    """Set ``weight`` to 1 wherever ``data`` does not exceed ``threshold``.

    The update is done in place; entries where ``data > threshold`` keep
    their previous value.

    Parameters
    ----------
    weight : torch.Tensor or numpy.ndarray
        Array to update; it is mutated.
    data : same kind as ``weight``
        Values compared against ``threshold``. The resulting boolean mask
        indexes ``weight``, so it must match ``weight``'s shape.
    threshold : scalar or array
        Upper bound (inclusive) for an entry to be marked.

    Returns
    -------
    The same ``weight`` object, after modification.
    """
    within = data <= threshold
    weight[within] = 1
    return weight
def weight_bythreshold_allbands(weight, l1m1, m1m2, thresholdmax):
    """Reduce per-band similarity flags to a single all-bands mask.

    Steps:

    1. In place, set ``weight`` to 1 where ``l1m1 <= thresholdmax[0]``.
    2. Likewise, set it to 1 where ``m1m2 <= thresholdmax[1]``.
    3. Average ``weight`` over dim 0.
    4. Keep only locations whose mean is exactly 1; every other location
       becomes 0.

    NOTE(review): the two conditions are OR-ed per element, because either
    one alone sets the entry to 1. Confirm this is intended rather than
    requiring both.

    Parameters
    ----------
    weight : torch.Tensor
        3-D tensor ``(bands, H, W)``; it is mutated.
    l1m1, m1m2 : torch.Tensor
        Difference tensors whose boolean masks index ``weight``.
    thresholdmax : sequence of two thresholds
        ``thresholdmax[0]`` applies to ``l1m1``; ``thresholdmax[1]`` applies
        to ``m1m2``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(1, H, W)`` holding 1 where the band mean equals 1
        and 0 elsewhere.
    """
    weight[l1m1 <= thresholdmax[0]] = 1
    weight[m1m2 <= thresholdmax[1]] = 1
    nbands, height, width = weight.shape[0], weight.shape[1], weight.shape[2]
    bandmean = weight.sum(0).view(1, height, width) / nbands
    return torch.where(bandmean == 1, bandmean, torch.zeros_like(bandmean))
###initial similar pixels tools################################################
def spectral_similar_threshold(clusters, NIR, red):
    """Compute the spectral-similarity threshold for the NIR and red bands.

    For each band the threshold is ``band.std() * 2 / clusters``. The band's
    own ``.std()`` method is used, so torch tensors use their default
    (unbiased) estimator and numpy arrays the population estimator.

    Parameters
    ----------
    clusters : number
        Number of classes the threshold is divided by.
    NIR, red : torch.Tensor or numpy.ndarray
        Band data.

    Returns
    -------
    tuple
        ``(thresholdNIR, thresholdred)``.
    """
    return tuple(band.std() * 2 / clusters for band in (NIR, red))
def caculate_similar(l1, threshold, window):
    """Flag the pixels in each moving window that are similar to its centre.

    ``l1`` is unfolded into sliding ``window`` patches with
    ``nn.functional.unfold``. An entry of the result is 1 where its absolute
    difference to the patch centre is ``<= threshold``, and 0 otherwise.

    Parameters
    ----------
    l1 : torch.Tensor
        Batched 4-D input ``(N, C, H, W)``. The later ``[:, c:c+1, :]``
        indexing requires ``unfold`` to return a 3-D tensor.
    threshold : float or torch.Tensor
        Maximum absolute difference for a neighbour to count as similar.
    window : int or tuple of int
        Kernel size passed to ``unfold``.

    Returns
    -------
    torch.Tensor
        float32 tensor with the unfolded shape ``(N, C * kh * kw, L)``,
        located on the same device as ``l1``.
    """
    patches = nn.functional.unfold(l1, window)
    # Allocate on the input's device so the boolean-mask assignment inside
    # weight_bythreshold also works for CUDA tensors (identical on CPU).
    weight = torch.zeros(patches.shape, dtype=torch.float32, device=patches.device)
    # NOTE(review): this is the middle of the whole C*kh*kw axis. It is the
    # window centre only when C == 1; confirm that callers pass one band.
    centerloc = (patches.size()[1] - 1) // 2
    diff = abs(patches - patches[:, centerloc:centerloc + 1, :])
    return weight_bythreshold(weight, diff, threshold)
def classifier(l1):
    """Placeholder for a classification step; not used.

    Parameters
    ----------
    l1 : object
        Ignored.

    Returns
    -------
    None
    """
    return None
###similar pixels filter tools#################################################
def allband_arrayindex(arraylist,indexarray,rawindexshape):
device= torch.device( "cpu")
shape=arraylist[0].shape
datalist=[]
for array in arraylist:
newarray=torch.zeros(rawindexshape,dtype=torch.float32).to(device)
for band in