import os
import numpy as np
from osgeo import gdal, osr
from scipy.stats import gamma, norm
from datetime import datetime
# Input / output locations.
# input_dir holds the yearly precipitation rasters; output_dir receives one SPI raster per month.
input_dir = r'F:\converted_clipped\pre\CN05'
output_dir = r'F:\converted_clipped\spi\CN_SPI1_monthly'
# Create the output directory up front; an already existing directory is not an error.
os.makedirs(output_dir, exist_ok=True)
# Days per month of a non-leap year. Callers copy this list and set February to 29 for leap years.
month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
# Switch the GDAL Python bindings to exception mode.
gdal.UseExceptions()
def is_leap(year):
    """Return True when *year* is a leap year of the Gregorian calendar.

    A year is a leap year if it is divisible by 400, or if it is divisible
    by 4 but not by 100.
    """
    if year % 400 == 0:
        return True
    return year % 4 == 0 and year % 100 != 0
def read_tif(file_path):
    """Load every band of a raster file into one float32 cube.

    Cells equal to the nodata value declared on band 1 are replaced by NaN
    in every band.

    Returns:
        A tuple ``(data, geo_transform, projection, nodata_value, cols, rows)``
        where ``data`` has shape ``(rows, cols, bands)``.
    """
    dataset = gdal.Open(file_path)
    n_bands = dataset.RasterCount
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    transform = dataset.GetGeoTransform()
    wkt = dataset.GetProjection()
    # The nodata marker of the first band is applied to all bands.
    nodata_value = dataset.GetRasterBand(1).GetNoDataValue()

    cube = np.zeros((height, width, n_bands), dtype=np.float32)
    for band_no in range(1, n_bands + 1):
        layer = dataset.GetRasterBand(band_no).ReadAsArray().astype(np.float32)
        layer[layer == nodata_value] = np.nan
        cube[:, :, band_no - 1] = layer
    return cube, transform, wkt, nodata_value, width, height
def write_tif(output_path, data, geo_transform, projection, nodata_value):
    """Write a 2-D grid to a single-band Float32 GeoTIFF.

    Args:
        output_path: Destination file path.
        data: 2-D array ``(rows, cols)``; NaN marks cells without a result.
        geo_transform: GDAL geotransform tuple to store in the file.
        projection: Projection string to store in the file.
        nodata_value: Nodata marker for the output band. ``None`` means the
            source declared none, in which case NaN is used as the marker.

    Raises:
        RuntimeError: If the output dataset cannot be created.
    """
    driver = gdal.GetDriverByName("GTiff")
    rows, cols = data.shape

    # Missing cells must carry the same value the band advertises as nodata,
    # otherwise readers would treat NaN cells as valid data.
    fill_value = np.nan if nodata_value is None else nodata_value
    # np.where builds a new array, so the caller's grid is left untouched.
    out_array = np.where(np.isnan(data), fill_value, data).astype(np.float32)

    out_ds = driver.Create(output_path, cols, rows, 1, gdal.GDT_Float32)
    if out_ds is None:
        raise RuntimeError(f"Cannot create raster: {output_path}")
    out_ds.SetGeoTransform(geo_transform)
    out_ds.SetProjection(projection)
    out_band = out_ds.GetRasterBand(1)
    out_band.SetNoDataValue(float(fill_value))
    out_band.WriteArray(out_array)
    out_band.FlushCache()
    # Drop the band before the dataset so GDAL flushes and closes the file.
    out_band = None
    out_ds = None
def gamma_params(hist_data):
    """Estimate the zero-inflated Gamma parameters of a precipitation sample.

    NaN entries are ignored. The Gamma distribution is fitted (location fixed
    at 0) to the strictly positive values only.

    Args:
        hist_data: 1-D array-like of historical precipitation amounts.

    Returns:
        ``(zero_prob, shape, scale)`` where ``zero_prob`` is the fraction of
        valid values equal to 0, or ``(None, None, None)`` when fewer than
        five positive values are available or the fit fails.
    """
    history = np.asarray(hist_data, dtype=float)
    valid_hist = history[~np.isnan(history)]
    non_zero = valid_hist[valid_hist > 0]
    if len(non_zero) < 5:
        return None, None, None
    try:
        shape, _, scale = gamma.fit(non_zero, floc=0)
    except (ValueError, RuntimeError, ArithmeticError):
        # Only fitting failures are treated as "no parameters"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return None, None, None
    zero_prob = np.mean(valid_hist == 0)
    return zero_prob, shape, scale
def compute_spi(month_data, hist_data):
    """Return the SPI value of one accumulated precipitation amount.

    The cumulative probability is ``q + (1 - q) * G(total)`` where ``q`` is the
    probability of zero precipitation and ``G`` the fitted Gamma CDF; the SPI
    is the standard-normal quantile of that probability.

    Args:
        month_data: Scalar or array of precipitation amounts. NaNs are
            ignored and the remaining values are summed.
        hist_data: Either a 1-D array-like of historical reference amounts
            (NaN = missing; at least 30 valid values and 5 non-zero values
            are required), or an already fitted ``(zero_prob, shape, scale)``
            tuple such as the one returned by ``gamma_params``.

    Returns:
        The SPI as a float, or NaN when it cannot be computed.
    """
    values = np.asarray(month_data, dtype=float)
    # nansum() would turn "no observation at all" into 0 mm, i.e. a dry month.
    if values.size == 0 or np.all(np.isnan(values)):
        return np.nan
    total = np.nansum(values)

    if isinstance(hist_data, tuple):
        # Pre-fitted parameters; (None, None, None) signals a failed fit.
        if len(hist_data) != 3 or any(p is None for p in hist_data):
            return np.nan
        zero_prob, shape, scale = hist_data
    else:
        history = np.asarray(hist_data, dtype=float)
        valid_hist = history[~np.isnan(history)]
        if len(valid_hist) < 30:  # at least 30 samples are required
            return np.nan
        # Zero precipitation is handled separately from the Gamma part.
        zero_mask = valid_hist == 0
        zero_prob = np.mean(zero_mask)
        non_zero = valid_hist[~zero_mask]
        if len(non_zero) < 5 or zero_prob >= 1.0:
            return np.nan
        try:
            shape, _, scale = gamma.fit(non_zero, floc=0)
        except (ValueError, RuntimeError, ArithmeticError):
            return np.nan

    # Written as a positive test so NaN parameters are rejected as well.
    if not (shape > 0 and scale > 0):
        return np.nan

    if total == 0:
        prob = zero_prob
    else:
        gamma_prob = gamma.cdf(total, a=shape, loc=0, scale=scale)
        prob = zero_prob + (1 - zero_prob) * gamma_prob
    # Keep the probability away from 0 and 1 so the quantile stays finite.
    prob = np.clip(prob, 1e-9, 1 - 1e-9)
    return norm.ppf(prob)
def process_month(year, month, data_slice, all_hist, geo_transform, projection, nodata):
    """Compute the SPI grid of one month and write it to ``output_dir``.

    Args:
        year: Calendar year, used in the output file name.
        month: Zero-based month index (0 = January).
        data_slice: Array ``(rows, cols, days)`` with the precipitation of
            this month; NaN marks missing values.
        all_hist: Sequence indexed by month; each entry is an array
            ``(rows, cols, samples)`` with the historical reference values.
        geo_transform: GDAL geotransform passed through to the output file.
        projection: Projection string passed through to the output file.
        nodata: Nodata value passed through to the output file.
    """
    rows, cols = data_slice.shape[:2]
    spi_grid = np.full((rows, cols), np.nan)
    month_hist = all_hist[month]

    # Pixels without a single valid day stay NaN; nansum() alone would report
    # them as 0 mm and they would be scored as a dry month.
    has_data = ~np.all(np.isnan(data_slice), axis=2)
    month_totals = np.nansum(data_slice, axis=2)

    # Per-pixel computation. compute_spi expects the raw history series and
    # performs the Gamma fit itself.
    for y, x in zip(*np.nonzero(has_data)):
        spi_grid[y, x] = compute_spi(month_totals[y, x], month_hist[y, x, :])

    # Write the result.
    output_file = os.path.join(output_dir, f"CN_SPI_{year}{month+1:02d}.tif")
    write_tif(output_file, spi_grid, geo_transform, projection, nodata)
    print(f"生成 {output_file}")
def _iter_month_slices(year, data_array):
    """Yield ``(month_index, daily_slice)`` for each calendar month of *year*."""
    days = month_days.copy()
    if is_leap(year):
        days[1] = 29
    start = 0
    for month_index, n_days in enumerate(days):
        stop = start + n_days
        yield month_index, data_array[:, :, start:stop]
        start = stop


def _monthly_total(daily_slice):
    """Sum the daily bands per pixel; NaN where a pixel has no valid day."""
    all_missing = np.all(np.isnan(daily_slice), axis=2)
    return np.where(all_missing, np.nan, np.nansum(daily_slice, axis=2))


def main(start_year=1961, end_year=2018):
    """Compute monthly SPI rasters for every available year in the range.

    The bands of each yearly input file are sliced as consecutive days of the
    year. Years whose input file does not exist are skipped in both steps.

    Args:
        start_year: First year to process (inclusive).
        end_year: Last year to process (inclusive).
    """
    # Step 1: collect the reference history. The SPI of a monthly total must
    # be evaluated against the distribution of monthly totals, so one total
    # per year is stored for every month rather than the daily values.
    totals_by_month = [[] for _ in range(12)]
    available_years = []
    for year in range(start_year, end_year + 1):
        file_path = os.path.join(input_dir, f"CN_pre_CN_{year}.tif")
        if not os.path.exists(file_path):
            continue
        data_array = read_tif(file_path)[0]
        available_years.append(year)
        for m, daily_slice in _iter_month_slices(year, data_array):
            totals_by_month[m].append(_monthly_total(daily_slice))

    if not available_years:
        print(f"No input files found in {input_dir}")
        return

    # Stack once at the end instead of concatenating inside the loop.
    monthly_hist = [np.stack(totals, axis=2) for totals in totals_by_month]

    # Step 2: compute the SPI of every month of every available year.
    for year in available_years:
        file_path = os.path.join(input_dir, f"CN_pre_CN_{year}.tif")
        data_array, geo_transform, projection, nodata, _, _ = read_tif(file_path)
        for m, daily_slice in _iter_month_slices(year, data_array):
            process_month(year, m, daily_slice, monthly_hist,
                          geo_transform, projection, nodata)
# Run the pipeline with the default year range only when executed as a script.
if __name__ == "__main__":
    main()
在此基础上显示完善之后的代码