# Fixed dust storm XGBoost+SHAP analysis - high-resolution English version
# Key improvements:
# 1. All figures exported at 900 DPI
# 2. All text labels in English
# 3. Figure styling follows top-tier journal conventions
# 4. Lagged-feature analysis included
# 5. SHAP beeswarm parameter issue fixed
# ==================================================
# 1. Install and load required packages
cat("Installing and loading required packages...\n")
required_packages <- c('xgboost', 'caret', 'Metrics', 'tibble', 'dplyr',
'ggplot2', 'pROC', 'viridis', 'ggpubr', 'ggbeeswarm',
'terra', 'tidyverse', 'lubridate', 'patchwork', 'rsample',
'MLmetrics', 'randomForest', 'sf')
for(pkg in required_packages) {
if(!require(pkg, character.only = TRUE, quietly = TRUE)) {
install.packages(pkg)
library(pkg, character.only = TRUE)
}
}
# 2. Set the working directory to drive I:
setwd("I:/")
cat("Working directory set to:", getwd(), "\n")
# 3. Common-resolution resampling function
resample_to_common_resolution <- function(data_list) {
cat("Resampling data to common resolution...\n")
# Use the DCMD data as the resolution template (50 km)
if("dcmd" %in% names(data_list)) {
template <- data_list$dcmd[[1]]
cat("Using DCMD data as resolution template (50km)\n")
} else {
# If DCMD is absent, fall back to the first available variable
first_var <- names(data_list)[1]
template <- data_list[[first_var]][[1]]
cat("Using", first_var, "data as resolution template\n")
}
cat("Target resolution:", round(res(template)[1] * 111, 2), "km\n")
# Resample each variable to the common resolution
resampled_data <- list()
for(var_name in names(data_list)) {
cat("Processing variable:", var_name, "\n")
var_data <- data_list[[var_name]]
if(is.null(var_data)) next
tryCatch({
# DCMD is the template itself, so keep it unchanged
if(var_name == "dcmd") {
resampled_data[[var_name]] <- var_data
cat(" -> Keeping original resolution\n")
} else {
# Resample all other variables onto the template grid
resampled_data[[var_name]] <- resample(var_data, template, method = "bilinear")
cat(" -> Resampling completed\n")
}
}, error = function(e) {
cat(" -> Resampling failed:", e$message, "\n")
})
}
return(resampled_data)
}
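# A minimal sketch of what resample_to_common_resolution() does, on synthetic
# rasters (all objects here are illustrative, not project data); wrapped in
# if (FALSE) so it never runs as part of the pipeline:
if (FALSE) {
  library(terra)
  coarse <- rast(nrows = 10, ncols = 10, xmin = 0, xmax = 10, ymin = 0, ymax = 10)
  values(coarse) <- runif(ncell(coarse))
  fine <- rast(nrows = 40, ncols = 40, xmin = 0, xmax = 10, ymin = 0, ymax = 10)
  values(fine) <- runif(ncell(fine))
  demo <- resample_to_common_resolution(list(dcmd = coarse, ndvi = fine))
  sapply(demo, res)  # both variables now share the coarse DCMD grid
}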
# 4. Data reading function
read_all_dust_data_annual <- function() {
cat("Reading annual dust and related environmental data...\n")
# Define the monthly time index (2005-2024)
time_index <- seq(as.Date("2005-01-01"), as.Date("2024-12-01"), by = "month")
years <- year(time_index)
months <- sprintf("%02d", month(time_index))
# Initialize data storage
all_data <- list()
# Generic raster-reading helper
read_raster_data <- function(file_pattern, var_name) {
files <- file_pattern
existing_files <- files[file.exists(files)]
if(length(existing_files) == 0) {
cat("Warning: No files found for", var_name, "\n")
return(NULL)
}
sorted_files <- sort(existing_files)
raster_data <- tryCatch({
rast(sorted_files)
}, error = function(e) {
cat("Reading", var_name, "failed:", e$message, "\n")
return(NULL)
})
return(raster_data)
}
# Read dust column mass density data (target variable)
cat("Reading dust column mass density data...\n")
dcmd_files <- file.path("H:/相关数据/2005-2024年北方风沙带MERRA2沙尘柱质量密度",
paste0("masked_DCMD_", years, "_", months, ".tif"))
all_data$dcmd <- read_raster_data(dcmd_files, "dcmd")
# Read predictor variables
cat("Reading precipitation data...\n")
precip_files <- file.path("H:/相关数据/2005-2024北方风沙带月尺度降水",
paste0("masked_ERA5Land_MonthlyPrecip_", years, months, ".tif"))
all_data$precip <- read_raster_data(precip_files, "precip")
cat("Reading strong wind hours data...\n")
wind_files <- file.path("H:/相关数据/2005-2024年北方风沙带月尺度大风总小时数",
paste0("masked_high_wind_hours_", years, "_", months, ".tif"))
all_data$wind_hours <- read_raster_data(wind_files, "wind_hours")
cat("Reading temperature data...\n")
temp_files <- file.path("H:/相关数据/2005-2024北方风沙带2m气温",
paste0("masked_ERA5Land_MeanTemp_", years, months, ".tif"))
all_data$temp <- read_raster_data(temp_files, "temp")
cat("Reading LST data...\n")
lst_files <- file.path("H:/相关数据/2005-2024年月尺度北方风沙带LST",
paste0("masked_LST_MONTHLY_", years, "_", months, ".tif"))
all_data$lst <- read_raster_data(lst_files, "lst")
cat("Reading NDVI data...\n")
ndvi_files <- file.path("H:/相关数据/2005-2024北方风沙带月尺度NDVI",
paste0("masked_NDVI_", years, "-", months, ".tif"))
all_data$ndvi <- read_raster_data(ndvi_files, "ndvi")
cat("Reading LAI data...\n")
lai_files <- file.path("H:/相关数据/2005-2024年月尺度北方风沙带LAI",
paste0("masked_LAI_MONTHLY_", years, "_", months, ".tif"))
all_data$lai <- read_raster_data(lai_files, "lai")
cat("Reading soil moisture data...\n")
sm_files <- file.path("H:/相关数据/2005-2024年月尺度北方风沙带根区土壤湿度SMRoot-GLEAM4",
paste0("masked_", years, "_", months, ".tif"))
all_data$sm <- read_raster_data(sm_files, "sm")
cat("Reading evapotranspiration data...\n")
et_files <- file.path("H:/相关数据/2005-2024年月尺度北方风沙带ET-GLEAM4",
paste0("masked_", years, "_", months, ".tif"))
all_data$et <- read_raster_data(et_files, "et")
cat("Reading VHI data...\n")
vhi_files <- file.path("H:/论文/沙尘暴特征/data/VHI计算",
paste0("VHI_", years, "_", months, ".tif"))
all_data$vhi <- read_raster_data(vhi_files, "vhi")
# Drop variables that failed to load
all_data <- all_data[!sapply(all_data, is.null)]
cat("Successfully read", length(all_data), "variables\n")
cat("Variables include:", paste(names(all_data), collapse = ", "), "\n")
return(list(data = all_data, time_index = time_index))
}
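# A quick sanity check (hypothetical, run interactively) that the monthly time
# index and file-name construction line up over the 2005-2024 study period:
if (FALSE) {
  idx <- seq(as.Date("2005-01-01"), as.Date("2024-12-01"), by = "month")
  length(idx)  # 240 monthly layers expected per variable
  head(paste0("masked_DCMD_", year(idx), "_", sprintf("%02d", month(idx)), ".tif"))
}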
# 5. Regional-level data preprocessing function
preprocess_regional_data <- function(data_list, time_index) {
cat("Preprocessing regional-level data...\n")
# Compute regional monthly means
regional_monthly <- data.frame()
for(var_name in names(data_list)) {
var_data <- data_list[[var_name]]
if(is.null(var_data)) next
# Regional mean for each monthly layer
for(i in 1:nlyr(var_data)) {
if(i > length(time_index)) break
monthly_mean <- tryCatch({
global(var_data[[i]], mean, na.rm = TRUE)[1,1]
}, error = function(e) {
NA
})
if(!is.na(monthly_mean)) {
regional_monthly <- rbind(regional_monthly,
data.frame(
year = year(time_index[i]),
month = month(time_index[i]),
date = time_index[i],
variable = var_name,
value = monthly_mean
))
}
}
}
# Convert to wide format
wide_data <- regional_monthly %>%
pivot_wider(names_from = variable, values_from = value) %>%
arrange(date)
# Drop rows containing NA
wide_data_clean <- wide_data[complete.cases(wide_data), ]
cat("Regional data preprocessing completed, dimensions:", dim(wide_data_clean), "\n")
cat("Time range:", min(wide_data_clean$date), "to", max(wide_data_clean$date), "\n")
cat("Available variables:", paste(setdiff(colnames(wide_data_clean), c("year", "month", "date")), collapse = ", "), "\n")
return(wide_data_clean)
}
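# The core step above is the long-to-wide reshape; a toy illustration with
# synthetic values showing how pivot_wider() turns one row per (date, variable)
# into one column per variable:
if (FALSE) {
  library(tidyr)
  toy_long <- data.frame(
    date = rep(as.Date(c("2005-01-01", "2005-02-01")), each = 2),
    variable = rep(c("dcmd", "precip"), times = 2),
    value = c(1.2, 30.1, 1.5, 12.4)
  )
  pivot_wider(toy_long, names_from = variable, values_from = value)
  # -> one row per date, with columns date, dcmd, precip
}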
# 6. Fixed pixel-level data preprocessing function
preprocess_pixel_level_data <- function(data_list, time_index, max_pixels_per_month = 100) {
cat("Processing pixel-level data...\n")
# Initialize storage
pixel_data_list <- list()
# Build the valid-pixel mask from the first variable
first_var <- names(data_list)[1]
template <- data_list[[first_var]][[1]]
# Safely obtain pixel coordinates
cat("Getting spatial coordinates...\n")
coords_ok <- tryCatch({
# Total number of template cells
total_cells <- ncell(template)
cat("Total template cells:", total_cells, "\n")
# Identify valid (non-NA) cells
cat("Identifying valid cells...\n")
valid_cells <- which(!is.na(values(template[[1]])))
cat("Number of valid cells:", length(valid_cells), "\n")
if(length(valid_cells) == 0) {
cat("Warning: No valid cells found\n")
return(NULL)
}
# Randomly sample cells to limit memory use
set.seed(123)
sampled_cells <- sample(valid_cells, min(max_pixels_per_month, length(valid_cells)))
cat("Sampled cells count:", length(sampled_cells), "\n")
# Use xyFromCell to retrieve coordinates safely
sampled_coords <- xyFromCell(template, sampled_cells)
cat("Sampled coordinates dimension:", dim(sampled_coords), "\n")
TRUE
}, error = function(e) {
cat("Failed to get coordinates:", e$message, "\n")
FALSE
})
# return() inside an error handler only exits the handler, so check a flag instead
if(!coords_ok) return(NULL)
# Process all time steps, capped to limit memory use
total_time_steps <- length(time_index)
max_time_steps <- min(240, total_time_steps) # cap at 240 time steps (all 20 years of monthly data)
cat(sprintf("Processing %d/%d time steps...\n", max_time_steps, total_time_steps))
# Extract pixel data for each time step
for(i in 1:max_time_steps) {
current_date <- time_index[i]
current_year <- year(current_date)
current_month <- month(current_date)
cat(sprintf("Processing time step %d/%d: %s (Year: %d, Month: %d)\n",
i, max_time_steps, current_date, current_year, current_month))
monthly_pixels <- data.frame(
pixel_id = sampled_cells,
year = current_year,
month = current_month,
date = current_date,
x = sampled_coords[, 1],
y = sampled_coords[, 2]
)
# Extract each variable's values
for(var_name in names(data_list)) {
var_data <- data_list[[var_name]]
if(!is.null(var_data) && nlyr(var_data) >= i) {
# Assign via tryCatch's return value: assigning inside the error handler
# would only modify the handler's local copy of monthly_pixels
monthly_pixels[[var_name]] <- tryCatch({
# Extract values safely by cell number
extracted_values <- terra::extract(var_data[[i]], sampled_cells)
# extract() may prepend an ID column; keep the value column
if(ncol(extracted_values) == 2) {
extracted_values[, 2]
} else {
extracted_values[, 1]
}
}, error = function(e) {
cat(" Failed to extract variable", var_name, ":", e$message, "\n")
rep(NA_real_, length(sampled_cells))
})
} else {
monthly_pixels[[var_name]] <- NA
cat(" Variable", var_name, "data insufficient or NULL\n")
}
}
pixel_data_list[[i]] <- monthly_pixels
# Free memory every 5 time steps
if(i %% 5 == 0) {
gc()
cat(" Memory cleanup completed\n")
}
}
# Combine all time steps
pixel_data <- bind_rows(pixel_data_list)
# Drop rows containing NA
original_rows <- nrow(pixel_data)
pixel_data_clean <- pixel_data[complete.cases(pixel_data), ]
cat("Pixel-level data preprocessing completed\n")
cat("Original data rows:", original_rows, "\n")
cat("Cleaned data rows:", nrow(pixel_data_clean), "\n")
cat("Total samples:", nrow(pixel_data_clean), "\n")
cat("Time range:", min(pixel_data_clean$date), "to", max(pixel_data_clean$date), "\n")
cat("Years included:", paste(sort(unique(pixel_data_clean$year)), collapse = ", "), "\n")
return(pixel_data_clean)
}
# 7. Lagged-feature construction function
create_lagged_features <- function(analysis_data, lag_months = c(0, 5)) {
cat("Creating lagged features...\n")
# Process each pixel's time series separately
lagged_data <- analysis_data %>%
arrange(pixel_id, date) %>%
group_by(pixel_id) %>%
# Create lags for the eco-hydrological factors
mutate(
# Immediate factors (kept as-is)
wind_hours_0 = wind_hours,
vhi_0 = vhi,
# Key 5-month-lagged eco-hydrological factors
precip_lag5 = lag(precip, 5),
temp_lag5 = lag(temp, 5),
ndvi_lag5 = lag(ndvi, 5),
lai_lag5 = lag(lai, 5),
lst_lag5 = lag(lst, 5),
sm_lag5 = lag(sm, 5),
et_lag5 = lag(et, 5),
vhi_lag5 = lag(vhi, 5),
# Additional lags (for comparison)
precip_lag3 = lag(precip, 3),
ndvi_lag3 = lag(ndvi, 3),
sm_lag3 = lag(sm, 3)
) %>%
ungroup() %>%
# Drop rows made NA by the lagging
filter(!is.na(precip_lag5) & !is.na(ndvi_lag5))
cat(sprintf("Lagged features created, samples reduced from %d to %d\n",
nrow(analysis_data), nrow(lagged_data)))
cat("New features:", paste(grep("_lag", colnames(lagged_data), value = TRUE), collapse = ", "), "\n")
return(lagged_data)
}
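# How the grouped lag behaves, on a toy series (synthetic data): dplyr::lag()
# shifts within each pixel_id only, so lagged values never leak across pixels.
if (FALSE) {
  library(dplyr)
  toy <- data.frame(
    pixel_id = rep(1:2, each = 6),
    date = rep(seq(as.Date("2005-01-01"), by = "month", length.out = 6), 2),
    precip = 1:12
  )
  toy %>%
    arrange(pixel_id, date) %>%
    group_by(pixel_id) %>%
    mutate(precip_lag5 = lag(precip, 5)) %>%
    ungroup()
  # the first five months of each pixel are NA and get filtered out downstream
}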
# 8. Create dust storm classification labels
create_dust_storm_labels <- function(data, dcmd_threshold_method = "quantile", threshold_value = 0.75) {
cat("Creating dust storm classification labels...\n")
dcmd_values <- data$dcmd
if(dcmd_threshold_method == "quantile") {
threshold <- quantile(dcmd_values, probs = threshold_value, na.rm = TRUE)
data$dust_storm <- ifelse(data$dcmd >= threshold, 1, 0)
cat(sprintf("Using quantile threshold: %.4f (%.0f%% quantile)\n", threshold, threshold_value * 100))
} else if(dcmd_threshold_method == "sd") {
mean_val <- mean(dcmd_values, na.rm = TRUE)
sd_val <- sd(dcmd_values, na.rm = TRUE)
threshold <- mean_val + sd_val
data$dust_storm <- ifelse(data$dcmd >= threshold, 1, 0)
cat(sprintf("Using standard deviation threshold: %.4f (mean + 1SD)\n", threshold))
} else {
data$dust_storm <- ifelse(data$dcmd >= threshold_value, 1, 0)
cat(sprintf("Using fixed threshold: %.4f\n", threshold_value))
}
# Summarize the class balance
storm_count <- sum(data$dust_storm)
non_storm_count <- sum(data$dust_storm == 0)
total_count <- nrow(data)
cat(sprintf("Dust storm samples: %d (%.1f%%)\n", storm_count, storm_count/total_count*100))
cat(sprintf("Non-dust storm samples: %d (%.1f%%)\n", non_storm_count, non_storm_count/total_count*100))
return(data)
}
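# Toy illustration of the quantile labelling (synthetic DCMD values): with the
# default 0.75 quantile, roughly the top quarter of samples become class 1.
if (FALSE) {
  toy <- data.frame(dcmd = c(0.10, 0.20, 0.30, 0.90))
  create_dust_storm_labels(toy, dcmd_threshold_method = "quantile",
                           threshold_value = 0.75)$dust_storm
  # [1] 0 0 0 1  (only 0.90 exceeds the 75% quantile of 0.45)
}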
# 9. Random 7:3 train/test split
split_train_test_data <- function(data, test_size = 0.3, target_var = "dust_storm") {
cat("Splitting data into training and test sets (7:3)...\n")
set.seed(123)
# Ensure the target variable exists
if(!target_var %in% colnames(data)) {
stop("Target variable does not exist: ", target_var)
}
# Stratified sampling so both sets contain dust storm samples;
# wrap the 0/1 label in factor() so createDataPartition stratifies by class
# rather than by quantile groups of a numeric variable
train_idx <- createDataPartition(factor(data[[target_var]]), p = 1 - test_size, list = FALSE)
train_data <- data[train_idx, ]
test_data <- data[-train_idx, ]
# Check the split
cat(sprintf("Training set: %d samples\n", nrow(train_data)))
cat(sprintf("Test set: %d samples\n", nrow(test_data)))
# Check class balance
train_storm_ratio <- mean(train_data[[target_var]])
test_storm_ratio <- mean(test_data[[target_var]])
cat(sprintf("Training set dust storm ratio: %.2f%%\n", train_storm_ratio * 100))
cat(sprintf("Test set dust storm ratio: %.2f%%\n", test_storm_ratio * 100))
# Re-split if the test set has no dust storm samples
if(test_storm_ratio == 0) {
cat("Warning: Test set has no dust storm samples, re-splitting...\n")
return(split_train_test_data_manual(data, test_size, target_var))
}
return(list(
train_data = train_data,
test_data = test_data
))
}
# 10. Manual split function - guarantees dust storm samples in the test set
split_train_test_data_manual <- function(data, test_size = 0.3, target_var = "dust_storm") {
cat("Using manual splitting to ensure test set diversity...\n")
set.seed(123)
# Separate storm and non-storm samples
storm_data <- data[data[[target_var]] == 1, ]
non_storm_data <- data[data[[target_var]] == 0, ]
# Number of test samples required per class
n_test_storm <- max(1, round(nrow(storm_data) * test_size))
n_test_non_storm <- round(nrow(non_storm_data) * test_size)
# Randomly pick test samples from each class
test_storm_idx <- sample(1:nrow(storm_data), n_test_storm)
test_non_storm_idx <- sample(1:nrow(non_storm_data), n_test_non_storm)
# Build the test set
test_data <- rbind(
storm_data[test_storm_idx, ],
non_storm_data[test_non_storm_idx, ]
)
# Build the training set
train_data <- rbind(
storm_data[-test_storm_idx, ],
non_storm_data[-test_non_storm_idx, ]
)
# Shuffle row order
train_data <- train_data[sample(1:nrow(train_data)), ]
test_data <- test_data[sample(1:nrow(test_data)), ]
# Check the result
cat(sprintf("Training set: %d samples (dust storms: %.2f%%)\n",
nrow(train_data), mean(train_data[[target_var]]) * 100))
cat(sprintf("Test set: %d samples (dust storms: %.2f%%)\n",
nrow(test_data), mean(test_data[[target_var]]) * 100))
return(list(
train_data = train_data,
test_data = test_data
))
}
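# Behaviour check for the manual split on a deliberately imbalanced toy set
# (synthetic data): the rare class is guaranteed at least one test sample.
if (FALSE) {
  toy <- data.frame(x = rnorm(100), dust_storm = rep(c(0, 1), c(90, 10)))
  sp <- split_train_test_data_manual(toy, test_size = 0.3)
  table(sp$test_data$dust_storm)  # ~27 non-storm and 3 storm samples
}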
# 11. Classification metric computation
calculate_classification_metrics <- function(y_train, y_test,
train_pred_class, test_pred_class,
train_pred_prob, test_pred_prob) {
# Training-set metrics
train_accuracy <- sum(train_pred_class == y_train) / length(y_train)
# Check both classes are present
if(length(unique(y_train)) == 2) {
# Namespace-qualified: caret, Metrics and MLmetrics export similarly named functions
train_precision <- Metrics::precision(y_train, train_pred_class)
train_recall <- Metrics::recall(y_train, train_pred_class)
train_f1 <- MLmetrics::F1_Score(y_train, train_pred_class)
train_auc <- pROC::auc(pROC::roc(y_train, train_pred_prob))
} else {
train_precision <- NA
train_recall <- NA
train_f1 <- NA
train_auc <- NA
cat("Warning: Training set has only one class, cannot calculate precision, recall, F1 and AUC\n")
}
# Test-set metrics
test_accuracy <- sum(test_pred_class == y_test) / length(y_test)
if(length(unique(y_test)) == 2) {
test_precision <- Metrics::precision(y_test, test_pred_class)
test_recall <- Metrics::recall(y_test, test_pred_class)
test_f1 <- MLmetrics::F1_Score(y_test, test_pred_class)
test_auc <- pROC::auc(pROC::roc(y_test, test_pred_prob))
} else {
test_precision <- NA
test_recall <- NA
test_f1 <- NA
test_auc <- NA
cat("Warning: Test set has only one class, cannot calculate precision, recall, F1 and AUC\n")
}
# Confusion matrices
train_cm <- table(Actual = y_train, Predicted = train_pred_class)
test_cm <- table(Actual = y_test, Predicted = test_pred_class)
return(list(
train_accuracy = train_accuracy,
train_precision = train_precision,
train_recall = train_recall,
train_f1 = train_f1,
train_auc = train_auc,
test_accuracy = test_accuracy,
test_precision = test_precision,
test_recall = test_recall,
test_f1 = test_f1,
test_auc = test_auc,
train_cm = train_cm,
test_cm = test_cm
))
}
# 12. Print classification results
print_classification_results <- function(metrics) {
cat("\n=== Classification Model Evaluation Results ===\n")
cat("Training Set Performance:\n")
cat(sprintf(" Accuracy: %.4f\n", metrics$train_accuracy))
if(!is.na(metrics$train_precision)) {
cat(sprintf(" Precision: %.4f\n", metrics$train_precision))
cat(sprintf(" Recall: %.4f\n", metrics$train_recall))
cat(sprintf(" F1 Score: %.4f\n", metrics$train_f1))
cat(sprintf(" AUC: %.4f\n", metrics$train_auc))
} else {
cat(" Precision: N/A (only one class)\n")
cat(" Recall: N/A (only one class)\n")
cat(" F1 Score: N/A (only one class)\n")
cat(" AUC: N/A (only one class)\n")
}
cat("\nTest Set Performance:\n")
cat(sprintf(" Accuracy: %.4f\n", metrics$test_accuracy))
if(!is.na(metrics$test_precision)) {
cat(sprintf(" Precision: %.4f\n", metrics$test_precision))
cat(sprintf(" Recall: %.4f\n", metrics$test_recall))
cat(sprintf(" F1 Score: %.4f\n", metrics$test_f1))
cat(sprintf(" AUC: %.4f\n", metrics$test_auc))
} else {
cat(" Precision: N/A (only one class)\n")
cat(" Recall: N/A (only one class)\n")
cat(" F1 Score: N/A (only one class)\n")
cat(" AUC: N/A (only one class)\n")
}
cat("\nTraining Set Confusion Matrix:\n")
print(metrics$train_cm)
cat("\nTest Set Confusion Matrix:\n")
print(metrics$test_cm)
}
# 13. Feature-set definition function
define_feature_sets <- function() {
list(
# Immediate features only (baseline model)
immediate_only = c("wind_hours", "vhi", "precip", "temp",
"ndvi", "lai", "lst", "sm", "et"),
# Lagged features only
lagged_only = c("wind_hours_0", "vhi_0",
"precip_lag5", "temp_lag5", "ndvi_lag5", "lai_lag5",
"lst_lag5", "sm_lag5", "et_lag5", "vhi_lag5"),
# Hybrid features (recommended) - based on the lag-time analysis
hybrid_optimal = c("wind_hours_0", "vhi_0", # strongly correlated immediate factors
"temp_lag5", "et_lag5", "ndvi_lag5", # strong 5-month-lag correlations
"lst_lag5", "lai_lag5", "precip_lag5", # important 5-month-lag factors
"sm_lag5") # moderate 5-month-lag correlation
)
}
# 14. Improved XGBoost classification analysis
perform_xgboost_classification_improved <- function(train_data, test_data,
target_var = "dust_storm",
feature_set_name = "hybrid_optimal") {
cat("Performing improved XGBoost classification analysis...\n")
cat("Using feature set:", feature_set_name, "\n")
set.seed(123)
# Fetch the feature set
feature_sets <- define_feature_sets()
if(!feature_set_name %in% names(feature_sets)) {
stop("Unknown feature set name: ", feature_set_name)
}
features <- feature_sets[[feature_set_name]]
# Keep only features that exist and are non-constant
valid_features <- c()
for(feature in features) {
if(feature %in% colnames(train_data) && sd(train_data[[feature]], na.rm = TRUE) > 0) {
valid_features <- c(valid_features, feature)
}
}
if(length(valid_features) == 0) {
stop("No valid features available for modeling")
}
cat(sprintf("Valid features: %d (original features: %d)\n", length(valid_features), length(features)))
cat("Features used:", paste(valid_features, collapse = ", "), "\n")
# Prepare data matrices
X_train <- as.matrix(train_data[, valid_features])
X_test <- as.matrix(test_data[, valid_features])
y_train <- train_data[[target_var]]
y_test <- test_data[[target_var]]
# Standardize the data
preprocess_params <- preProcess(X_train, method = c("center", "scale"))
X_train_scaled <- predict(preprocess_params, X_train)
X_test_scaled <- predict(preprocess_params, X_test)
# Train the XGBoost classifier
xgb_model <- xgboost(
data = X_train_scaled,
label = y_train,
max_depth = 3,
eta = 0.05,
nthread = 2,
nrounds = 150,
objective = "binary:logistic",
eval_metric = "logloss",
verbose = 0,
early_stopping_rounds = 25,
subsample = 0.7,
colsample_bytree = 0.7,
lambda = 2,
alpha = 0.5,
min_child_weight = 3
)
# Predict probabilities
train_pred_prob <- predict(xgb_model, X_train_scaled)
test_pred_prob <- predict(xgb_model, X_test_scaled)
# Convert to class predictions (0.5 threshold)
train_pred_class <- ifelse(train_pred_prob > 0.5, 1, 0)
test_pred_class <- ifelse(test_pred_prob > 0.5, 1, 0)
# Compute evaluation metrics
metrics <- calculate_classification_metrics(y_train, y_test,
train_pred_class, test_pred_class,
train_pred_prob, test_pred_prob)
# Print the evaluation
print_classification_results(metrics)
return(list(
model = xgb_model,
features = valid_features,
X_train = X_train_scaled,
X_test = X_test_scaled,
y_train = y_train,
y_test = y_test,
train_pred_prob = train_pred_prob,
test_pred_prob = test_pred_prob,
train_pred_class = train_pred_class,
test_pred_class = test_pred_class,
metrics = metrics,
feature_names = valid_features,
preprocess_params = preprocess_params,
feature_set_name = feature_set_name
))
}
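# The hyperparameters above are fixed a priori; if tuning were desired, nrounds
# could be chosen by cross-validated logloss with xgboost::xgb.cv. A
# self-contained sketch on synthetic data (not part of the pipeline):
if (FALSE) {
  set.seed(123)
  X_demo <- matrix(rnorm(200 * 5), nrow = 200,
                   dimnames = list(NULL, paste0("f", 1:5)))
  y_demo <- rbinom(200, 1, plogis(X_demo[, 1]))
  cv <- xgboost::xgb.cv(
    data = X_demo, label = y_demo,
    params = list(objective = "binary:logistic", eval_metric = "logloss",
                  max_depth = 3, eta = 0.05, subsample = 0.7),
    nrounds = 500, nfold = 5, early_stopping_rounds = 25, verbose = 0
  )
  cv$best_iteration  # candidate value for nrounds in the model above
}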
# 15. Compare performance across feature sets
compare_feature_sets <- function(train_data, test_data, target_var = "dust_storm") {
cat("Comparing model performance across different feature sets...\n")
feature_sets <- define_feature_sets()
results <- list()
for(set_name in names(feature_sets)) {
cat(sprintf("\n=== Training feature set: %s ===\n", set_name))
# Filter to available features
features <- feature_sets[[set_name]]
available_features <- features[features %in% colnames(train_data)]
if(length(available_features) > 0) {
# Train the model
xgb_result <- perform_xgboost_classification_improved(
train_data, test_data, target_var, set_name
)
results[[set_name]] <- list(
features = available_features,
metrics = xgb_result$metrics,
model = xgb_result$model
)
} else {
cat(" No available features, skipping\n")
}
}
# Print the comparison
cat("\n=== Feature Set Performance Comparison ===\n")
comparison_df <- data.frame()
for(set_name in names(results)) {
metrics <- results[[set_name]]$metrics
comparison_df <- rbind(comparison_df, data.frame(
Feature_Set = set_name,
Train_Accuracy = metrics$train_accuracy,
Test_Accuracy = metrics$test_accuracy,
Test_Precision = ifelse(is.na(metrics$test_precision), NA, metrics$test_precision),
Test_Recall = ifelse(is.na(metrics$test_recall), NA, metrics$test_recall),
Test_F1 = ifelse(is.na(metrics$test_f1), NA, metrics$test_f1),
Test_AUC = ifelse(is.na(metrics$test_auc), NA, metrics$test_auc),
Num_Features = length(results[[set_name]]$features)
))
}
print(comparison_df)
return(list(
results = results,
comparison = comparison_df
))
}
# 16. Journal-style high-resolution beeswarm plot function (final optimized version)
create_stable_beeswarm_plot <- function(shap_matrix, X_train, importance_data, output_dir) {
cat("Creating stable high-resolution beeswarm plot using standard ggplot...\n")
tryCatch({
# 1. Sample and prepare the data
max_samples <- min(800, nrow(shap_matrix))
set.seed(123)
sample_idx <- sample(1:nrow(shap_matrix), max_samples)
shap_df <- as.data.frame(shap_matrix[sample_idx, , drop = FALSE])
feature_df <- as.data.frame(X_train[sample_idx, , drop = FALSE])
# 2. Ensure feature column names match
common_features <- intersect(colnames(shap_df), colnames(feature_df))
if (length(common_features) == 0) {
stop("SHAP矩阵与特征矩阵无共同特征")
}
# 3. Build long-format data
plot_data <- data.frame()
for (f in common_features) {
if (f %in% colnames(shap_df) && f %in% colnames(feature_df)) {
temp <- data.frame(
feature = f,
shap_value = shap_df[[f]],
feature_value = feature_df[[f]],
stringsAsFactors = FALSE
)
plot_data <- rbind(plot_data, temp)
}
}
# 4. Set factor levels
valid_features <- importance_data$feature[importance_data$feature %in% common_features]
if (length(valid_features) == 0) valid_features <- common_features
# Order features by importance (reversed so the most important sits on top)
plot_data$feature <- factor(plot_data$feature, levels = rev(valid_features))
# Drop invalid rows
plot_data <- plot_data[!is.na(plot_data$shap_value) & !is.na(plot_data$feature_value), ]
if (nrow(plot_data) == 0) {
stop("No valid data after filtering NA values")
}
# 5. Build a stable beeswarm-style plot: geom_point with vertical-only jitter
p <- ggplot2::ggplot(
plot_data,
ggplot2::aes(x = shap_value, y = feature, color = feature_value)
) +
# Jitter only vertically so the SHAP values on the x-axis are not distorted;
# position_jitterdodge() is avoided because it needs a discrete dodge variable
ggplot2::geom_point(
position = ggplot2::position_jitter(width = 0, height = 0.2, seed = 123),
alpha = 0.8,
size = 1.8
) +
# Add a reference line at zero
ggplot2::geom_vline(
xintercept = 0,
linetype = "dashed",
color = "black",
linewidth = 0.7
) +
# Color mapping
ggplot2::scale_color_gradient2(
low = "#1E88E5", # bright blue
mid = "#FFFFFF", # white midpoint
high = "#E53935", # bright red
midpoint = median(plot_data$feature_value, na.rm = TRUE),
name = "Feature Value",
guide = ggplot2::guide_colorbar(
barwidth = 12,
barheight = 0.8,
title.position = "top",
title.hjust = 0.5
)
) +
# Axis and title labels
ggplot2::labs(
x = "SHAP Value (Impact on Model Output)",
y = "Features",
title = "SHAP Feature Importance Beeswarm Plot",
subtitle = "Each point represents a sample's feature contribution"
) +
# Theme settings
ggplot2::theme_minimal(base_size = 14) +
ggplot2::theme(
# ggplot2::margin() must be namespace-qualified here: randomForest, attached
# later in required_packages, masks it, and the unqualified call fails with
# 'argument "observed" is missing, with no default'
plot.title = ggplot2::element_text(
hjust = 0.5,
face = "bold",
size = 16,
color = "black",
margin = ggplot2::margin(b = 10)
),
plot.subtitle = ggplot2::element_text(
hjust = 0.5,
size = 12,
color = "black",
margin = ggplot2::margin(b = 15)
),
axis.title.x = ggplot2::element_text(
face = "bold",
size = 14,
color = "black",
margin = ggplot2::margin(t = 10)
),
axis.title.y = ggplot2::element_text(
face = "bold",
size = 14,
color = "black",
margin = ggplot2::margin(r = 10)
),
axis.text.x = ggplot2::element_text(
size = 12,
color = "black",
face = "bold"
),
axis.text.y = ggplot2::element_text(
size = 12,
color = "black",
face = "bold"
),
# Horizontal grid lines only
panel.grid.major.x = ggplot2::element_blank(),
panel.grid.minor.x = ggplot2::element_blank(),
panel.grid.major.y = ggplot2::element_line(
color = "grey90",
linewidth = 0.4
),
panel.grid.minor.y = ggplot2::element_blank(),
axis.line.x = ggplot2::element_line(
color = "black",
linewidth = 0.6
),
legend.position = "bottom",
legend.title = ggplot2::element_text(face = "bold", size = 12),
legend.text = ggplot2::element_text(size = 10),
plot.background = ggplot2::element_rect(fill = "white", color = NA),
panel.background = ggplot2::element_rect(fill = "white", color = NA),
plot.margin = ggplot2::margin(20, 20, 20, 20)
) +
# Expand the x-axis slightly
ggplot2::scale_x_continuous(
expand = ggplot2::expansion(mult = c(0.03, 0.03))
)
# 6. Save high-resolution output
shap_dir <- file.path(output_dir, "HighRes_SHAP_Results")
dir.create(shap_dir, showWarnings = FALSE, recursive = TRUE)
png_file <- file.path(shap_dir, "SHAP_Beeswarm_Stable.png")
ggplot2::ggsave(
png_file,
plot = p,
width = 11,
height = max(8, 0.6 * length(valid_features)),
dpi = 900,
bg = "white"
)
pdf_file <- file.path(shap_dir, "SHAP_Beeswarm_Stable.pdf")
ggplot2::ggsave(
pdf_file,
plot = p,
width = 11,
height = max(8, 0.6 * length(valid_features)),
device = cairo_pdf,
bg = "white"
)
cat("Stable beeswarm plot saved:", png_file, "and", pdf_file, "\n")
return(p)
}, error = function(e) {
cat("Stable beeswarm plot generation failed:", e$message, "\n")
return(NULL)
})
}
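# Why the theme code above qualifies ggplot2::margin(): the reported beeswarm
# failure ('argument "observed" is missing, with no default') matches
# randomForest::margin() masking ggplot2::margin() once randomForest is
# attached. A quick interactive check of which package a name resolves to:
if (FALSE) {
  environmentName(environment(margin))           # likely "randomForest" here
  environmentName(environment(ggplot2::margin))  # "ggplot2"
}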
# 17. Lagged-feature importance analysis function (new)
analyze_lagged_feature_importance <- function(shap_analysis, xgb_results, output_dir) {
cat("Analyzing lagged feature importance patterns...\n")
lag_dir <- file.path(output_dir, "Lagged_Feature_Analysis")
dir.create(lag_dir, showWarnings = FALSE, recursive = TRUE)
tryCatch({
# Fetch feature importance data
importance_data <- shap_analysis$importance_percentage
# Identify lagged vs. immediate features
lag_features <- grep("_lag", importance_data$feature, value = TRUE)
immediate_features <- grep("_0$|^[^_]*$", importance_data$feature, value = TRUE)
cat(sprintf("Found %d lagged features and %d immediate features\n",
length(lag_features), length(immediate_features)))
if(length(lag_features) > 0) {
# Build the lagged-feature analysis plot
lag_data <- importance_data[importance_data$feature %in% lag_features, ]
# Parse lag length from feature names
lag_data$lag_months <- sapply(lag_data$feature, function(x) {
if(grepl("_lag5", x)) return(5)
if(grepl("_lag3", x)) return(3)
if(grepl("_lag1", x)) return(1)
return(0)
})
lag_data$base_feature <- gsub("_lag\\d+", "", lag_data$feature)
lag_data$base_feature <- gsub("_0$", "", lag_data$base_feature)
# Lagged-feature importance plot
p_lag <- ggplot(lag_data, aes(x = reorder(feature, percentage_contribution),
y = percentage_contribution,
fill = as.factor(lag_months))) +
geom_col(alpha = 0.9) +
scale_fill_manual(
values = c("0" = "#1f77b4", "3" = "#ff7f0e", "5" = "#2ca02c"),
name = "Lag (months)"
) +
coord_flip() +
labs(
title = "Lagged Feature Importance Analysis",
subtitle = "Impact of time-lagged environmental factors on dust storm prediction",
x = "Features with Lag Periods",
y = "SHAP Contribution Percentage (%)"
) +
theme_minimal(base_size = 12) +
theme(
plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
plot.subtitle = element_text(hjust = 0.5, size = 10),
axis.title = element_text(face = "bold", size = 12),
axis.text = element_text(size = 10, color = "black"),
legend.position = "bottom"
)
# Save the high-resolution image
ggsave(file.path(lag_dir, "Lagged_Feature_Importance.png"),
p_lag, width = 10, height = 8, dpi = 900, bg = "white")
# Summarize by base feature
feature_summary <- lag_data %>%
group_by(base_feature) %>%
summarise(
max_contribution = max(percentage_contribution),
best_lag = lag_months[which.max(percentage_contribution)],
num_lags = n()
) %>%
arrange(desc(max_contribution))
# Save analysis tables
write.csv(feature_summary,
file.path(lag_dir, "Lagged_Feature_Summary.csv"),
row.names = FALSE)
write.csv(lag_data,
file.path(lag_dir, "Lagged_Feature_Detailed_Analysis.csv"),
row.names = FALSE)
# Lag comparison plot
if(nrow(feature_summary) > 0) {
p_comparison <- ggplot(feature_summary,
aes(x = reorder(base_feature, max_contribution),
y = max_contribution,
fill = as.factor(best_lag))) +
geom_col(alpha = 0.9) +
scale_fill_manual(
values = c("0" = "#1f77b4", "3" = "#ff7f0e", "5" = "#2ca02c"),
name = "Optimal Lag (months)"
) +
coord_flip() +
labs(
title = "Optimal Lag Periods for Environmental Factors",
subtitle = "Best performing lag period for each environmental variable",
x = "Environmental Factors",
y = "Maximum SHAP Contribution (%)"
) +
theme_minimal(base_size = 12) +
theme(
plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
plot.subtitle = element_text(hjust = 0.5, size = 10),
axis.title = element_text(face = "bold", size = 12),
axis.text = element_text(size = 10, color = "black"),
legend.position = "bottom"
)
ggsave(file.path(lag_dir, "Optimal_Lag_Periods.png"),
p_comparison, width = 10, height = 8, dpi = 900, bg = "white")
}
cat("Lagged feature analysis completed successfully\n")
return(list(
lag_plot = p_lag,
comparison_plot = if(exists("p_comparison")) p_comparison else NULL,
summary = feature_summary
))
} else {
cat("No lagged features found for analysis\n")
return(NULL)
}
}, error = function(e) {
cat("Lagged feature analysis failed:", e$message, "\n")
return(NULL)
})
}
# 18. Fixed high-resolution SHAP analysis function
create_highres_shap_analysis <- function(xgb_result, output_dir) {
cat("Creating high-resolution SHAP analysis...\n")
shap_dir <- file.path(output_dir, "HighRes_SHAP_Results")
dir.create(shap_dir, showWarnings = FALSE, recursive = TRUE)
tryCatch({
# Compute SHAP values
shap_values <- predict(
xgb_result$model,
xgb_result$X_train,
predcontrib = TRUE,
approxcontrib = FALSE
)
# Keep the per-feature contributions (predcontrib appends a trailing BIAS column)
if (ncol(shap_values) == length(xgb_result$features) + 1) {
shap_matrix <- shap_values[, 1:length(xgb_result$features), drop = FALSE]
} else {
shap_matrix <- shap_values
}
# Ensure column names are correct
if (ncol(shap_matrix) == length(xgb_result$features)) {
colnames(shap_matrix) <- xgb_result$features
}
# Compute feature importance
feature_importance_percentage <- data.frame(
feature = colnames(shap_matrix),
mean_abs_shap = colMeans(abs(shap_matrix)),
stringsAsFactors = FALSE
) %>%
mutate(
percentage_contribution = mean_abs_shap / sum(mean_abs_shap) * 100,
rank = rank(-percentage_contribution)
) %>%
arrange(desc(percentage_contribution))
# High-resolution importance plot
create_highres_importance_plot(feature_importance_percentage, shap_dir)
# Stable beeswarm plot
beeswarm_plot <- create_stable_beeswarm_plot(
shap_matrix, xgb_result$X_train, feature_importance_percentage, output_dir)
# Save the importance table
write.csv(feature_importance_percentage,
file.path(shap_dir, "Feature_Importance_Percentage.csv"),
row.names = FALSE)
cat("高分辨率SHAP分析完成\n")
return(list(
importance_percentage = feature_importance_percentage,
shap_matrix = shap_matrix,
beeswarm_plot = beeswarm_plot
))
}, error = function(e) {
cat("高分辨率SHAP分析错误:", e$message, "\n")
return(NULL)
})
}
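# predict(..., predcontrib = TRUE) returns one column per feature plus a final
# "BIAS" column, which is why the code above strips the last column. A
# self-contained check on synthetic data; by SHAP additivity, each row sums to
# the raw (log-odds) prediction:
if (FALSE) {
  set.seed(123)
  X_demo <- matrix(rnorm(100 * 3), nrow = 100,
                   dimnames = list(NULL, c("a", "b", "c")))
  y_demo <- rbinom(100, 1, 0.5)
  m <- xgboost::xgboost(data = X_demo, label = y_demo, nrounds = 10,
                        objective = "binary:logistic", verbose = 0)
  contrib <- predict(m, X_demo, predcontrib = TRUE)
  colnames(contrib)  # "a" "b" "c" "BIAS"
  all.equal(unname(rowSums(contrib)),
            predict(m, X_demo, outputmargin = TRUE))  # TRUE
}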
# 19. High-resolution feature importance plot function
create_highres_importance_plot <- function(importance_data, output_dir) {
cat("Creating high-resolution feature importance plot...\n")
# Keep the 15 most important features
plot_data <- head(importance_data, 15)
# Build the plot
p <- ggplot(plot_data, aes(x = percentage_contribution,
y = reorder(feature, percentage_contribution))) +
geom_col(fill = "#3498db", alpha = 0.9, width = 0.8) +
geom_text(aes(label = sprintf("%.1f%%", percentage_contribution)),
hjust = -0.1, size = 5, color = "#2c3e50", fontface = "bold") +
scale_x_continuous(expand = expansion(mult = c(0, 0.1))) +
labs(
title = "Feature Contribution Analysis for Dust Storms",
subtitle = "Feature importance ranking based on SHAP values",
x = "Contribution Percentage (%)",
y = "Environmental Drivers"
) +
theme_minimal(base_size = 14) +
theme(
plot.title = element_text(hjust = 0.5, face = "bold", size = 16),
plot.subtitle = element_text(hjust = 0.5, size = 12),
axis.title = element_text(face = "bold", size = 14),
axis.text = element_text(size = 12, color = "black"),
axis.text.y = element_text(face = "bold"),
plot.background = element_rect(fill = "white", color = NA),
panel.grid.major = element_line(color = "grey90", linewidth = 0.2),
panel.grid.minor = element_line(color = "grey95", linewidth = 0.1)
)
# Save the high-resolution PNG
ggsave(file.path(output_dir, "HighRes_Feature_Importance.png"),
p, width = 12, height = 8, dpi = 900, bg = "white")
# Also save a PDF version
ggsave(file.path(output_dir, "HighRes_Feature_Importance.pdf"),
p, width = 12, height = 8, device = cairo_pdf, bg = "white")
cat("高分辨率特征重要性图已保存\n")
}
# 20. High-resolution ROC curve function
create_highres_roc_curve <- function(xgb_result, output_dir) {
cat("Creating high-resolution ROC curve...\n")
roc_dir <- file.path(output_dir, "HighRes_ROC_Curves")
dir.create(roc_dir, showWarnings = FALSE, recursive = TRUE)
tryCatch({
# Check that the response has two levels
if(length(unique(xgb_result$y_train)) == 2 && length(unique(xgb_result$y_test)) == 2) {
# Training-set ROC
roc_train <- pROC::roc(
response = xgb_result$y_train,
predictor = xgb_result$train_pred_prob,
levels = c(0, 1),
direction = "<"
)
# Test-set ROC
roc_test <- pROC::roc(
response = xgb_result$y_test,
predictor = xgb_result$test_pred_prob,
levels = c(0, 1),
direction = "<"
)
# AUC values
train_auc <- round(roc_train$auc, 3)
test_auc <- round(roc_test$auc, 3)
# Assemble ROC data
roc_data <- rbind(
data.frame(
fpr = 1 - roc_train$specificities,
tpr = roc_train$sensitivities,
dataset = "Training Set"
),
data.frame(
fpr = 1 - roc_test$specificities,
tpr = roc_test$sensitivities,
dataset = "Test Set"
)
)
# Build the high-resolution ROC curve (linewidth replaces the line `size`
# aesthetic deprecated in ggplot2 >= 3.4)
p_roc <- ggplot(roc_data, aes(x = fpr, y = tpr, color = dataset)) +
geom_line(linewidth = 1.2) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed",
color = "gray50", linewidth = 0.8) +
scale_color_manual(
values = c("Training Set" = "#377EB8", "Test Set" = "#E41A1C"),
labels = c(
paste0("Training Set (AUC = ", train_auc, ")"),
paste0("Test Set (AUC = ", test_auc, ")")
)
) +
coord_equal() +
theme_bw(base_size = 14) +
labs(
x = "False Positive Rate (1 - Specificity)",
y = "True Positive Rate (Sensitivity)",
color = "Dataset"
) +
theme(
axis.title = element_text(face = "bold", size = 14),
axis.text = element_text(size = 12, color = "black"),
legend.title = element_text(face = "bold", size = 12),
legend.text = element_text(size = 11),
legend.position = c(0.7, 0.3),
legend.background = element_rect(fill = "white", color = "black"),
panel.grid.major = element_line(color = "grey90", linewidth = 0.3),
panel.grid.minor = element_line(color = "grey95", linewidth = 0.1),
panel.border = element_rect(color = "black", fill = NA, linewidth = 0.8)
) +
scale_x_continuous(expand = expansion(mult = c(0, 0.02))) +
scale_y_continuous(expand = expansion(mult = c(0, 0.02)))
# Save the high-resolution PNG
ggsave(file.path(roc_dir, "HighRes_ROC_Curve.png"),
p_roc, width = 8, height = 7, dpi = 900, bg = "white")
# Also save a PDF version for publication
ggsave(file.path(roc_dir, "HighRes_ROC_Curve.pdf"),
p_roc, width = 8, height = 7, device = cairo_pdf, bg = "white")
# AUC comparison table
auc_comparison <- data.frame(
Dataset = c("Training Set", "Test Set"),
AUC = c(train_auc, test_auc),
Samples = c(length(xgb_result$y_train), length(xgb_result$y_test))
)
write.csv(auc_comparison,
file.path(roc_dir, "AUC_Comparison.csv"),
row.names = FALSE)
cat("高分辨率ROC曲线已保存\n")
cat(sprintf("训练集AUC: %.3f\n", train_auc))
cat(sprintf("测试集AUC: %.3f\n", test_auc))
return(p_roc)
} else {
cat("响应变量只有一个水平,跳过ROC曲线绘制\n")
return(NULL)
}
}, error = function(e) {
cat("创建高分辨率ROC曲线时出错:", e$message, "\n")
return(NULL)
})
}
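# Why direction = "<" matters in the pROC calls above: it declares that higher
# predicted probabilities indicate the positive class, so the AUC is not
# silently flipped. A tiny self-contained example (synthetic values):
if (FALSE) {
  library(pROC)
  resp <- c(0, 0, 0, 1, 1)
  prob <- c(0.10, 0.30, 0.55, 0.40, 0.90)
  r <- roc(response = resp, predictor = prob, levels = c(0, 1), direction = "<")
  auc(r)
}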
# 21. High-resolution model diagnostics function
create_highres_classification_diagnostics <- function(xgb_result, output_dir) {
cat("Creating high-resolution classification model diagnostics...\n")
diagnostics_dir <- file.path(output_dir, "HighRes_Model_Diagnostics")
dir.create(diagnostics_dir, showWarnings = FALSE, recursive = TRUE)
plots <- list()
# 1. High-resolution ROC curve
roc_plot <- create_highres_roc_curve(xgb_result, diagnostics_dir)
if(!is.null(roc_plot)) {
plots$roc_plot <- roc_plot
}
return(plots)
}
# 22. Package check function
check_and_install_packages <- function() {
cat("Checking and installing required packages...\n")
required_packages <- c('xgboost', 'caret', 'Metrics', 'tibble', 'dplyr',
'ggplot2', 'pROC', 'viridis', 'ggpubr', 'ggbeeswarm',
'terra', 'tidyverse', 'lubridate', 'patchwork', 'rsample',
'MLmetrics', 'randomForest', 'sf')
for(pkg in required_packages) {
if(!require(pkg, character.only = TRUE, quietly = TRUE)) {
install.packages(pkg, dependencies = TRUE)
library(pkg, character.only = TRUE)
cat("Installed and loaded:", pkg, "\n")
} else {
cat("Package already installed:", pkg, "\n")
}
}
}
# 23. Comprehensive report generation function
generate_comprehensive_report <- function(xgb_results, shap_results, data_info, output_dir) {
cat("Generating comprehensive analysis report...\n")
report_dir <- file.path(output_dir, "Comprehensive_Report")
dir.create(report_dir, showWarnings = FALSE, recursive = TRUE)
# Write a simple text report
report_content <- paste(
"Dust Storm Probability XGBoost+SHAP Analysis Report",
paste("Generated time:", Sys.time()),
paste("Data processing level:", data_info$data_type),
paste("Total samples:", data_info$sample_count),
paste("Number of features:", data_info$feature_count),
"",
"Model Performance:",
paste("Training Accuracy:", round(xgb_results$metrics$train_accuracy, 4)),
paste("Test Accuracy:", round(xgb_results$metrics$test_accuracy, 4)),
paste("Test AUC:", ifelse(is.na(xgb_results$metrics$test_auc), "N/A",
round(xgb_results$metrics$test_auc, 4))),
"",
"Analysis completed successfully with high-resolution outputs.",
sep = "\n"
)
writeLines(report_content, file.path(report_dir, "Analysis_Report.txt"))
cat("Comprehensive report generated:", report_dir, "\n")
}
# 24. Main analysis function - fixed version
main_dust_storm_analysis_highres <- function(output_base = "I:/DustStorm_Analysis/XGBoost_SHAP_HighRes",
processing_level = "pixel",
max_pixels = 100,
include_lag_features = TRUE,
feature_set = "hybrid_optimal") {
# Check and install required packages
check_and_install_packages()
cat("Starting improved dust storm probability XGBoost analysis (high-resolution version)...\n")
cat("Data processing level:", processing_level, "\n")
cat("Output directory:", output_base, "\n")
cat("Include lagged features:", include_lag_features, "\n")
cat("Feature set used:", feature_set, "\n")
# Create the output directory
dir.create(output_base, showWarnings = FALSE, recursive = TRUE)
# Set the terra temporary directory
temp_dir <- file.path(output_base, "temp_terra")
dir.create(temp_dir, showWarnings = FALSE)
terraOptions(tempdir = temp_dir)
cat("terra temporary directory set to:", temp_dir, "\n")
tryCatch({
# 1. Read data
cat("=== Step 1: Reading data ===\n")
raw_data <- read_all_dust_data_annual()
if(is.null(raw_data) || length(raw_data$data) == 0) {
stop("Failed to read any data, please check file paths")
}
# 2. Resample to a common resolution
cat("=== Step 2: Resampling to common resolution ===\n")
processed_data <- resample_to_common_resolution(raw_data$data)
# 3. Preprocess data (per the chosen level)
cat("=== Step 3: Preprocessing data ===\n")
if(processing_level == "pixel") {
analysis_data <- preprocess_pixel_level_data(processed_data, raw_data$time_index,
max_pixels_per_month = max_pixels)
data_type <- "pixel_level"
} else {
analysis_data <- preprocess_regional_data(processed_data, raw_data$time_index)
data_type <- "regional_level"
}
if(is.null(analysis_data) || nrow(analysis_data) == 0) {
stop("Data preprocessing failed, cannot continue analysis")
}
cat("Original sample size:", nrow(analysis_data), "\n")
cat("Data years:", paste(sort(unique(analysis_data$year)), collapse = ", "), "\n")
# 4. Build lagged features
if(include_lag_features) {
cat("=== Step 4: Creating lagged features ===\n")
analysis_data <- create_lagged_features(analysis_data)
cat("Lagged features created successfully\n")
cat("Available features:", paste(setdiff(colnames(analysis_data),
c("year", "month", "date", "pixel_id", "x", "y")),
collapse = ", "), "\n")
}
# 5. Create dust storm labels
cat("=== Step 5: Creating dust storm labels ===\n")
analysis_data <- create_dust_storm_labels(analysis_data,
dcmd_threshold_method = "quantile",
threshold_value = 0.75)
# 6. Split into training and test sets
cat("=== Step 6: Splitting data into training and test sets ===\n")
split_data <- split_train_test_data(analysis_data, test_size = 0.3)
# Check the split
if(is.null(split_data$train_data) || is.null(split_data$test_data)) {
stop("Data splitting failed")
}
cat(sprintf("Training set samples: %d\n", nrow(split_data$train_data)))
cat(sprintf("Test set samples: %d\n", nrow(split_data$test_data)))
cat(sprintf("Training set dust storm ratio: %.2f%%\n",
mean(split_data$train_data$dust_storm) * 100))
cat(sprintf("Test set dust storm ratio: %.2f%%\n",
mean(split_data$test_data$dust_storm) * 100))
# 7. XGBoost classification
cat("=== Step 7: Performing XGBoost classification ===\n")
if(feature_set == "compare_all") {
# 比较所有特征集
comparison_results <- compare_feature_sets(split_data$train_data, split_data$test_data)
xgb_results <- comparison_results$results$hybrid_optimal
} else {
# Use the specified feature set
xgb_results <- perform_xgboost_classification_improved(
split_data$train_data, split_data$test_data,
feature_set_name = feature_set
)
}
# 8. High-resolution SHAP analysis
cat("=== Step 8: Performing SHAP analysis ===\n")
shap_analysis <- create_highres_shap_analysis(xgb_results, output_base)
# 9. Lagged-feature analysis
if(include_lag_features && !is.null(shap_analysis)) {
cat("=== Step 9: Analyzing lagged feature importance ===\n")
analyze_lagged_feature_importance(shap_analysis, xgb_results, output_base)
}
# 10. High-resolution model diagnostics
cat("=== Step 10: Creating model diagnostics ===\n")
diagnostic_plots <- create_highres_classification_diagnostics(xgb_results, output_base)
# 11. Generate a comprehensive report
cat("=== Step 11: Generating comprehensive report ===\n")
data_info <- list(
sample_count = nrow(analysis_data),
feature_count = length(setdiff(colnames(analysis_data),
c("year", "month", "date", "pixel_id", "x", "y", "dust_storm"))),
data_type = data_type
)
generate_comprehensive_report(xgb_results, shap_analysis, data_info, output_base)
# Final summary output
separator_line <- paste(rep("=", 80), collapse = "")
cat("\n", separator_line, "\n", sep = "")
cat("=== Improved Dust Storm Probability XGBoost+SHAP Analysis Completed (High-Resolution) ===\n")
cat(separator_line, "\n")
cat("Output directory:", output_base, "\n")
cat("Data processing level:", processing_level, "\n")
cat("Total sample size:", nrow(analysis_data), "\n")
cat("Training set samples:", nrow(split_data$train_data), "\n")
cat("Test set samples:", nrow(split_data$test_data), "\n")
cat("Lagged features included:", include_lag_features, "\n")
cat("Feature set used:", feature_set, "\n")
cat("Main improvements:\n")
cat("- Complete data processing pipeline added\n")
cat("- High-resolution output (900 DPI) with English labels\n")
cat("- Journal-style formatting for all figures\n")
cat("- All known errors fixed\n")
return(list(
xgb_results = xgb_results,
shap_analysis = shap_analysis,
diagnostics = diagnostic_plots,
analysis_data = analysis_data,
split_data = split_data,
processing_level = processing_level,
include_lag_features = include_lag_features
))
}, error = function(e) {
cat("\n=== Analysis Failed ===\n")
cat("Error message:", e$message, "\n")
cat("Traceback:\n")
traceback()
return(NULL)
})
}
# Run the analysis
cat("High-resolution dust storm XGBoost+SHAP analysis script loaded\n")
cat("Run the following command to start analysis:\n")
cat("results_highres <- main_dust_storm_analysis_highres(processing_level = 'pixel', max_pixels = 100, include_lag_features = TRUE)\n")
# Start the analysis
cat("\nStarting high-resolution analysis...\n")
results_highres <- main_dust_storm_analysis_highres(
output_base = "I:/DustStorm_Analysis/XGBoost_SHAP_HighRes",
processing_level = "pixel",
max_pixels = 100,
include_lag_features = TRUE,
feature_set = "hybrid_optimal"
)