[R] Bind element of List of matrix or data.frame or list

本文探讨了如何使用 R 语言将矩阵或数据框列表中的元素行进行拼接的方法,包括直接使用 do.call(rbind,)、plyr 库的 ldply 函数、data.table 库的 rbindlist 函数等,并对比了它们的效率。

1. Alternative solutions for list of matrix or data.frame

Given a list of matrices or data.frames, the following approaches can be used to bind the rows of all elements into a single object.

Firstly, I generate toy data

# Toy data: four 3 x 2 matrices of standard-normal draws ...
myList1 <- lapply(seq_len(4), function(k) {
  matrix(rnorm(2 * 3), ncol = 2)
})
# ... and four data.frames built the same way (columns V1 and V2).
myList2 <- lapply(seq_len(4), function(k) {
  as.data.frame(matrix(rnorm(2 * 3), ncol = 2))
})

Now I list the alternative solutions

# solution 1: base R, works for both matrices and data.frames
result1.1 <- do.call(rbind, myList1)
head(result1.1)
result1.2 <- do.call(rbind, myList2)
head(result1.2)

# solution 2
## plyr: the split-apply-combine paradigm for R
library(plyr)
result1.2.1 <- ldply(myList1, rbind)
head(result1.2.1)
## rbind.fill() only accepts data.frames; on a list of matrices it fails with
###Error: All inputs to rbind.fill must be data.frames
## try() records the failure without aborting the rest of the script.
result1.2.2 <- try(rbind.fill(myList1))
if (!inherits(result1.2.2, "try-error")) {
  head(result1.2.2)
}
result2.2.1 <- ldply(myList2, rbind)
head(result2.2.1)
result2.2.2 <- rbind.fill(myList2)
head(result2.2.2)

# solution 3
## data.table: Enhanced data.frame
library(data.table)
## rbindlist() rejects matrices as list items; it fails with
###Error in rbindlist(myList1) : Item 1 of list input is not a data.frame, data.table or list
## try() again keeps the demonstration running.
result1.3 <- try(rbindlist(myList1))
if (!inherits(result1.3, "try-error")) {
  head(result1.3)
}
result2.3 <- rbindlist(myList2)
head(result2.3)

From the code above, we can see that

  • for a list of matrices, only do.call(rbind, ) and ldply work; rbind.fill and rbindlist raise an error.
  • for a list of data.frames, all four approaches work: do.call(rbind, ), ldply, rbind.fill and rbindlist.

Now let us benchmark all the solutions on the list of data.frames.

# benchmark
## rbenchmark::benchmark() is a simple wrapper around system.time; each
## expression below is timed on the list of data.frames.
library(rbenchmark)
benchmark(
  do.call(rbind, myList2),
  ldply(myList2, rbind),
  rbind.fill(myList2),
  rbindlist(myList2)
)

2. Alternative solutions for list of list

#############list of list#############################

# generate list of list: 4 outer elements, each holding 2 numeric vectors
# of length 3
myList3 <- lapply(seq_len(4), function(outer_idx) {
  lapply(seq_len(2), function(inner_idx) rnorm(3))
})

# bind each element of outer list: every outer element becomes a 2 x 3 matrix
tempList <- lapply(myList3, function(inner) do.call(rbind, inner))
## row i of every matrix is picked (same as the selection operator "[" with
## arguments i, TRUE) and the picked rows are stacked into one matrix
bind.ith.rows <- function(i) {
  ith_rows <- lapply(tempList, function(m) m[i, TRUE])
  do.call(rbind, ith_rows)
}
nr <- nrow(tempList[[1]])
lapply(seq_len(nr), bind.ith.rows)
# Dust-storm XGBoost + SHAP analysis (fixed version, high-resolution output,
# English labels).
# Main improvements:
#   1. All figures are exported at 900 DPI.
#   2. All text labels are in English.
#   3. Figure styling follows top-tier journal conventions.
#   4. Lagged-feature analysis is included.
#   5. The SHAP beeswarm plot parameter problem is fixed.
# ==================================================

# 1. Install (when missing) and load the required packages
cat("Installing and loading required packages...\n")
# unique() removes the duplicated "pROC" entry while keeping the attach order.
required_packages <- unique(c(
  "xgboost", "caret", "Metrics", "tibble", "dplyr", "ggplot2", "pROC",
  "viridis", "ggpubr", "ggbeeswarm", "terra", "tidyverse", "lubridate",
  "patchwork", "rsample", "pROC", "MLmetrics", "randomForest", "sf"
))
for (pkg in required_packages) {
  # requireNamespace() only checks availability; library() then fails loudly
  # if the package still cannot be attached after the install attempt.
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
  library(pkg, character.only = TRUE, quietly = TRUE)
}

# 2. Set the working directory to drive I:
# NOTE(review): hard-coded, machine-specific path; setwd() errors when the
# drive does not exist.
setwd("I:/")
cat("Working directory set to:", getwd(), "\n")

# 3. Resample every variable to one common resolution
#
# data_list: named list of raster objects (one multi-layer raster per
#            variable; presumably terra SpatRaster objects - confirm against
#            the caller). NULL entries are skipped.
# Returns a named list with the same variables resampled (bilinear) onto the
# grid of the first layer of "dcmd" when present, otherwise onto the first
# layer of the first variable. Variables whose resampling fails are reported
# and left out of the result.
resample_to_common_resolution <- function(data_list) {
  cat("Resampling data to common resolution...\n")

  if (length(data_list) == 0) {
    stop("data_list is empty: there is nothing to resample", call. = FALSE)
  }

  # Choose the template grid: DCMD when available, else the first variable.
  if ("dcmd" %in% names(data_list)) {
    template <- data_list[["dcmd"]][[1]]
    cat("Using DCMD data as resolution template (50km)\n")
  } else {
    first_var <- names(data_list)[1]
    template <- data_list[[first_var]][[1]]
    cat("Using", first_var, "data as resolution template\n")
  }

  # The factor 111 converts cell size to km; this assumes the raster
  # resolution is expressed in degrees - TODO confirm.
  cat("Target resolution:", round(terra::res(template)[1] * 111, 2), "km\n")

  resampled_data <- list()
  for (var_name in names(data_list)) {
    cat("Processing variable:", var_name, "\n")
    var_data <- data_list[[var_name]]
    if (is.null(var_data)) next

    resampled <- tryCatch({
      if (var_name == "dcmd") {
        # DCMD is kept untouched.
        cat(" -> Keeping original resolution\n")
        var_data
      } else {
        out <- terra::resample(var_data, template, method = "bilinear")
        cat(" -> Resampling completed\n")
        out
      }
    }, error = function(e) {
      cat(" -> Resampling failed:", e$message, "\n")
      NULL
    })

    if (!is.null(resampled)) {
      resampled_data[[var_name]] <- resampled
    }
  }

  resampled_data
}

# 4.
# Data-reading function
#
# Reads the monthly GeoTIFF series (2005-01 .. 2024-12) of the target
# variable (dust column mass density, "dcmd") and of the environmental
# drivers from hard-coded directories on drive H:.
#
# Returns list(data = <named list of rasters, one per variable that could be
# read>, time_index = <Date vector of the 240 month starts>).
read_all_dust_data_annual <- function() {
  cat("Reading annual dust and related environmental data...\n")

  # Monthly time axis (2005-2024)
  time_index <- seq(as.Date("2005-01-01"), as.Date("2024-12-01"), by = "month")
  years <- lubridate::year(time_index)
  months <- sprintf("%02d", lubridate::month(time_index))

  # Read the existing files of one variable into a single multi-layer raster.
  # Returns NULL when no file exists or when reading fails.
  read_raster_data <- function(file_pattern, var_name) {
    files <- file_pattern
    existing_files <- files[file.exists(files)]

    if (length(existing_files) == 0) {
      cat("Warning: No files found for", var_name, "\n")
      return(NULL)
    }

    # Downstream code pairs layer i with time_index[i]; a gap in the file
    # series silently shifts every later layer, so report it explicitly.
    n_missing <- length(files) - length(existing_files)
    if (n_missing > 0) {
      cat("Warning:", n_missing, "of", length(files), "files missing for",
          var_name, "- layer order will not match time_index\n")
    }

    sorted_files <- sort(existing_files)
    tryCatch(
      terra::rast(sorted_files),
      error = function(e) {
        cat("Reading", var_name, "failed:", e$message, "\n")
        NULL
      }
    )
  }

  # File names follow "<prefix><year><sep><month>.tif" inside each directory.
  build_files <- function(dir, prefix, sep) {
    file.path(dir, paste0(prefix, years, sep, months, ".tif"))
  }

  # One entry per variable; "dcmd" is the target variable.
  sources <- list(
    list(var = "dcmd",
         msg = "Reading dust column mass density data...\n",
         dir = "H:/相关数据/2005-2024年北方风沙带MERRA2沙尘柱质量密度",
         prefix = "masked_DCMD_", sep = "_"),
    list(var = "precip",
         msg = "Reading precipitation data...\n",
         dir = "H:/相关数据/2005-2024北方风沙带月尺度降水",
         prefix = "masked_ERA5Land_MonthlyPrecip_", sep = ""),
    list(var = "wind_hours",
         msg = "Reading strong wind hours data...\n",
         dir = "H:/相关数据/2005-2024年北方风沙带月尺度大风总小时数",
         prefix = "masked_high_wind_hours_", sep = "_"),
    list(var = "temp",
         msg = "Reading temperature data...\n",
         dir = "H:/相关数据/2005-2024北方风沙带2m气温",
         prefix = "masked_ERA5Land_MeanTemp_", sep = ""),
    list(var = "lst",
         msg = "Reading LST data...\n",
         dir = "H:/相关数据/2005-2024年月尺度北方风沙带LST",
         prefix = "masked_LST_MONTHLY_", sep = "_"),
    list(var = "ndvi",
         msg = "Reading NDVI data...\n",
         dir = "H:/相关数据/2005-2024北方风沙带月尺度NDVI",
         prefix = "masked_NDVI_", sep = "-"),
    list(var = "lai",
         msg = "Reading LAI data...\n",
         dir = "H:/相关数据/2005-2024年月尺度北方风沙带LAI",
         prefix = "masked_LAI_MONTHLY_", sep = "_"),
    list(var = "sm",
         msg = "Reading soil moisture data...\n",
         dir = "H:/相关数据/2005-2024年月尺度北方风沙带根区土壤湿度SMRoot-GLEAM4",
         prefix = "masked_", sep = "_"),
    list(var = "et",
         msg = "Reading evapotranspiration data...\n",
         dir = "H:/相关数据/2005-2024年月尺度北方风沙带ET-GLEAM4",
         prefix = "masked_", sep = "_"),
    list(var = "vhi",
         msg = "Reading VHI data...\n",
         dir = "H:/论文/沙尘暴特征/data/VHI计算",
         prefix = "VHI_", sep = "_")
  )

  all_data <- list()
  for (src in sources) {
    cat(src$msg)
    # Assigning NULL leaves the variable out of the list.
    all_data[[src$var]] <- read_raster_data(
      build_files(src$dir, src$prefix, src$sep),
      src$var
    )
  }

  # Drop variables that could not be read
  all_data <- all_data[!vapply(all_data, is.null, logical(1))]

  cat("Successfully read", length(all_data), "variables\n")
  cat("Variables include:", paste(names(all_data), collapse = ", "), "\n")

  list(data = all_data, time_index = time_index)
}

# 5.
# Regional-level data preprocessing
#
# data_list:  named list of multi-layer rasters, one per variable; layer i is
#             paired with time_index[i].
# time_index: Date vector giving the month of each layer.
#
# For every variable and month the mean over all cells (NA removed) is
# computed; the result is reshaped to one row per month and one column per
# variable, sorted by date, and rows containing any NA are removed.
# Returns the cleaned wide table.
preprocess_regional_data <- function(data_list, time_index) {
  cat("Preprocessing regional-level data...\n")

  # Long-format monthly means of one variable (NULL when nothing usable).
  summarise_variable <- function(var_name) {
    var_data <- data_list[[var_name]]
    if (is.null(var_data)) {
      return(NULL)
    }

    # Layers beyond the end of time_index are ignored.
    steps <- seq_len(min(terra::nlyr(var_data), length(time_index)))
    monthly_means <- vapply(steps, function(i) {
      tryCatch(
        as.numeric(terra::global(var_data[[i]], mean, na.rm = TRUE)[1, 1]),
        error = function(e) NA_real_
      )
    }, numeric(1))

    keep <- !is.na(monthly_means)
    if (!any(keep)) {
      return(NULL)
    }
    dates <- time_index[steps][keep]
    data.frame(
      year = lubridate::year(dates),
      month = lubridate::month(dates),
      date = dates,
      variable = var_name,
      value = monthly_means[keep]
    )
  }

  # Build one table per variable and bind once (no rbind growth in a loop).
  per_variable <- lapply(names(data_list), summarise_variable)
  per_variable <- Filter(Negate(is.null), per_variable)
  if (length(per_variable) == 0) {
    stop("No regional means could be computed from data_list", call. = FALSE)
  }
  regional_monthly <- do.call(rbind, per_variable)

  # Reshape to wide format
  wide_data <- regional_monthly %>%
    tidyr::pivot_wider(names_from = variable, values_from = value) %>%
    dplyr::arrange(date)

  # Drop rows that contain NA
  wide_data_clean <- wide_data[complete.cases(wide_data), ]

  # cat() prints Date objects as day numbers, so format them first.
  if (nrow(wide_data_clean) > 0) {
    date_range <- format(range(wide_data_clean$date))
  } else {
    date_range <- c("NA", "NA")
  }

  cat("Regional data preprocessing completed, dimensions:",
      dim(wide_data_clean), "\n")
  cat("Time range:", date_range[1], "to", date_range[2], "\n")
  cat("Available variables:",
      paste(setdiff(colnames(wide_data_clean), c("year", "month", "date")),
            collapse = ", "),
      "\n")

  wide_data_clean
}

# 6.
# Pixel-level data preprocessing (fixed version)
#
# data_list:            named list of multi-layer rasters, one per variable;
#                       layer i is paired with time_index[i].
# time_index:           Date vector giving the month of each layer.
# max_pixels_per_month: maximum number of cells sampled from the template;
#                       the same sampled cells are used for every month.
#
# Returns a data frame with one row per sampled cell and month (columns
# pixel_id, year, month, date, x, y plus one column per variable), with rows
# containing any NA removed. Returns NULL when no valid cell exists or the
# coordinates cannot be obtained.
preprocess_pixel_level_data <- function(data_list, time_index,
                                        max_pixels_per_month = 100) {
  cat("Processing pixel-level data...\n")

  # Valid-cell mask is taken from the first layer of the first variable.
  first_var <- names(data_list)[1]
  template <- data_list[[first_var]][[1]]

  cat("Getting spatial coordinates...\n")
  sampling <- tryCatch({
    total_cells <- terra::ncell(template)
    cat("Total template cells:", total_cells, "\n")

    cat("Identifying valid cells...\n")
    valid_cells <- which(!is.na(terra::values(template[[1]])))
    cat("Number of valid cells:", length(valid_cells), "\n")

    if (length(valid_cells) == 0) {
      cat("Warning: No valid cells found\n")
      NULL
    } else {
      # Sample a subset of cells to limit memory use. Indexing through
      # sample.int() avoids sample()'s special case for a single number,
      # which would draw from 1:valid_cells instead.
      set.seed(123)
      n_sample <- min(max_pixels_per_month, length(valid_cells))
      cells <- valid_cells[sample.int(length(valid_cells), n_sample)]
      cat("Sampled cells count:", length(cells), "\n")

      coords <- terra::xyFromCell(template, cells)
      cat("Sampled coordinates dimension:", dim(coords), "\n")
      list(cells = cells, coords = coords)
    }
  }, error = function(e) {
    cat("Failed to get coordinates:", e$message, "\n")
    NULL
  })

  # A handler's return() only leaves the handler, so the failure is checked
  # here to actually stop the function.
  if (is.null(sampling)) {
    return(NULL)
  }
  sampled_cells <- sampling$cells
  sampled_coords <- sampling$coords

  # At most 240 monthly time steps (20 years) are processed.
  total_time_steps <- length(time_index)
  max_time_steps <- min(240, total_time_steps)
  cat(sprintf("Processing %d/%d time steps...\n",
              max_time_steps, total_time_steps))

  pixel_data_list <- vector("list", max_time_steps)
  for (i in seq_len(max_time_steps)) {
    current_date <- time_index[i]
    current_year <- lubridate::year(current_date)
    current_month <- lubridate::month(current_date)

    cat(sprintf("Processing time step %d/%d: %s (Year: %d, Month: %d)\n",
                i, max_time_steps, current_date, current_year, current_month))

    monthly_pixels <- data.frame(
      pixel_id = sampled_cells,
      year = current_year,
      month = current_month,
      date = current_date,
      x = sampled_coords[, 1],
      y = sampled_coords[, 2]
    )

    # Extract the value of every variable at the sampled cells.
    # NOTE(review): cell numbers come from the template grid; this assumes
    # all variables share that grid (e.g. after resampling) - confirm.
    for (var_name in names(data_list)) {
      var_data <- data_list[[var_name]]
      if (!is.null(var_data) && terra::nlyr(var_data) >= i) {
        # The tryCatch value is assigned here; an assignment made inside the
        # error handler would be local to the handler and get lost.
        monthly_pixels[[var_name]] <- tryCatch({
          extracted_values <- terra::extract(var_data[[i]], sampled_cells)
          if (ncol(extracted_values) == 2) {
            extracted_values[, 2]
          } else {
            extracted_values[, 1]
          }
        }, error = function(e) {
          cat(" Failed to extract variable", var_name, ":", e$message, "\n")
          NA
        })
      } else {
        monthly_pixels[[var_name]] <- NA
        cat(" Variable", var_name, "data insufficient or NULL\n")
      }
    }

    pixel_data_list[[i]] <- monthly_pixels

    # Trigger garbage collection every 5 time steps
    if (i %% 5 == 0) {
      gc()
      cat(" Memory cleanup completed\n")
    }
  }

  # Combine all time steps
  pixel_data <- dplyr::bind_rows(pixel_data_list)

  # Drop rows that contain NA
  original_rows <- nrow(pixel_data)
  pixel_data_clean <- pixel_data[complete.cases(pixel_data), ]

  # cat() prints Date objects as day numbers, so format them first.
  if (nrow(pixel_data_clean) > 0) {
    date_range <- format(range(pixel_data_clean$date))
  } else {
    date_range <- c("NA", "NA")
  }

  cat("Pixel-level data preprocessing completed\n")
  cat("Original data rows:", original_rows, "\n")
  cat("Cleaned data rows:", nrow(pixel_data_clean), "\n")
  cat("Total samples:", nrow(pixel_data_clean), "\n")
  cat("Time range:", date_range[1], "to", date_range[2], "\n")
  cat("Years included:",
      paste(sort(unique(pixel_data_clean$year)), collapse = ", "), "\n")

  pixel_data_clean
}

# 7. Lagged-feature construction
#
# analysis_data: pixel-level table with columns pixel_id, date, wind_hours,
#                vhi, precip, temp, ndvi, lai, lst, sm and et.
# lag_months:    currently unused; the lags are fixed at 5 and 3 rows.
#
# Within each pixel (rows ordered by date) the function adds
#   * wind_hours_0 and vhi_0: copies of the immediate factors,
#   * <var>_lag5 for precip, temp, ndvi, lai, lst, sm, et and vhi,
#   * <var>_lag3 for precip, ndvi and sm,
# and drops rows where precip_lag5 or ndvi_lag5 is NA.
# NOTE(review): the lag is counted in rows, so it equals 5 (or 3) months only
# when every pixel has one row per consecutive month - confirm upstream.
create_lagged_features <- function(analysis_data, lag_months = c(0, 5)) {
  cat("Creating lagged features...\n")

  lag5_vars <- c("precip", "temp", "ndvi", "lai", "lst", "sm", "et", "vhi")
  lag3_vars <- c("precip", "ndvi", "sm")

  lagged_data <- analysis_data %>%
    dplyr::arrange(pixel_id, date) %>%
    dplyr::group_by(pixel_id) %>%
    dplyr::mutate(
      # Immediate factors (kept as they are)
      wind_hours_0 = wind_hours,
      vhi_0 = vhi,
      # dplyr::lag is named explicitly: stats::lag() does not shift a plain
      # vector, so a masking mix-up would produce unlagged "lag" columns.
      dplyr::across(
        dplyr::all_of(lag5_vars),
        function(x) dplyr::lag(x, 5),
        .names = "{.col}_lag5"
      ),
      # Additional lag used for comparison
      dplyr::across(
        dplyr::all_of(lag3_vars),
        function(x) dplyr::lag(x, 3),
        .names = "{.col}_lag3"
      )
    ) %>%
    dplyr::ungroup() %>%
    # Drop the leading rows that have no lagged value
    dplyr::filter(!is.na(precip_lag5) & !is.na(ndvi_lag5))

  cat(sprintf("Lagged features created, samples reduced from %d to %d\n",
              nrow(analysis_data), nrow(lagged_data)))
  cat("New features:",
      paste(grep("_lag", colnames(lagged_data), value = TRUE),
            collapse = ", "),
      "\n")

  lagged_data
}

# 8.
# Create dust-storm classification labels
#
# data:                  data frame with a numeric column `dcmd`.
# dcmd_threshold_method: "quantile" (threshold = the `threshold_value`
#                        quantile of dcmd), "sd" (threshold = mean + 1 SD of
#                        dcmd); any other value uses `threshold_value` itself
#                        as a fixed threshold.
# threshold_value:       quantile probability or fixed threshold.
#
# Returns `data` with an added column `dust_storm` (1 when dcmd >= threshold,
# 0 otherwise, NA where dcmd is NA) and prints the class counts.
create_dust_storm_labels <- function(data, dcmd_threshold_method = "quantile",
                                     threshold_value = 0.75) {
  cat("Creating dust storm classification labels...\n")

  dcmd_values <- data$dcmd

  # Pick the threshold; the labelling itself is shared by all methods.
  if (dcmd_threshold_method == "quantile") {
    threshold <- stats::quantile(dcmd_values, probs = threshold_value,
                                 na.rm = TRUE)
    cat(sprintf("Using quantile threshold: %.4f (%.0f%% quantile)\n",
                threshold, threshold_value * 100))
  } else if (dcmd_threshold_method == "sd") {
    threshold <- mean(dcmd_values, na.rm = TRUE) +
      stats::sd(dcmd_values, na.rm = TRUE)
    cat(sprintf("Using standard deviation threshold: %.4f (mean + 1SD)\n",
                threshold))
  } else {
    threshold <- threshold_value
    cat(sprintf("Using fixed threshold: %.4f\n", threshold_value))
  }

  data$dust_storm <- ifelse(data$dcmd >= threshold, 1, 0)

  # Class statistics; na.rm keeps the counts meaningful when dcmd has NA.
  storm_count <- sum(data$dust_storm == 1, na.rm = TRUE)
  non_storm_count <- sum(data$dust_storm == 0, na.rm = TRUE)
  total_count <- nrow(data)

  cat(sprintf("Dust storm samples: %d (%.1f%%)\n",
              storm_count, storm_count / total_count * 100))
  cat(sprintf("Non-dust storm samples: %d (%.1f%%)\n",
              non_storm_count, non_storm_count / total_count * 100))

  data
}

# 9.
# Random 7:3 split into training and test sets
#
# data:       data frame containing the target column.
# test_size:  share of rows assigned to the test set.
# target_var: name of the 0/1 target column.
#
# Returns list(train_data, test_data). When the test set ends up without any
# dust-storm sample, the split is redone with split_train_test_data_manual().
split_train_test_data <- function(data, test_size = 0.3,
                                  target_var = "dust_storm") {
  cat("Splitting data into training and test sets (7:3)...\n")

  set.seed(123)

  # The target variable must exist
  if (!target_var %in% colnames(data)) {
    stop("Target variable does not exist: ", target_var)
  }

  # createDataPartition() stratifies by class only for factor outcomes; a
  # numeric 0/1 vector is binned by percentiles instead, so the target is
  # converted to a factor to get a class-stratified split.
  train_idx <- as.vector(caret::createDataPartition(
    factor(data[[target_var]]),
    p = 1 - test_size,
    list = FALSE
  ))

  train_data <- data[train_idx, ]
  test_data <- data[-train_idx, ]

  # Report the split
  cat(sprintf("Training set: %d samples\n", nrow(train_data)))
  cat(sprintf("Test set: %d samples\n", nrow(test_data)))

  # Report the class balance
  train_storm_ratio <- mean(train_data[[target_var]])
  test_storm_ratio <- mean(test_data[[target_var]])

  cat(sprintf("Training set dust storm ratio: %.2f%%\n",
              train_storm_ratio * 100))
  cat(sprintf("Test set dust storm ratio: %.2f%%\n", test_storm_ratio * 100))

  # Re-split when the test set holds no dust-storm sample
  if (!any(test_data[[target_var]] == 1, na.rm = TRUE)) {
    cat("Warning: Test set has no dust storm samples, re-splitting...\n")
    return(split_train_test_data_manual(data, test_size, target_var))
  }

  list(
    train_data = train_data,
    test_data = test_data
  )
}

# 10.
# Manual split - makes sure the test set contains dust-storm samples
#
# data:       data frame containing the target column.
# test_size:  share of each class assigned to the test set.
# target_var: name of the 0/1 target column.
#
# Both classes are sampled separately; at least one dust-storm row goes to
# the test set whenever a dust-storm row exists. Rows whose target is NA are
# left out of both sets. Returns list(train_data, test_data), each shuffled.
split_train_test_data_manual <- function(data, test_size = 0.3,
                                         target_var = "dust_storm") {
  cat("Using manual splitting to ensure test set diversity...\n")

  set.seed(123)

  # Separate the two classes; which() skips NA targets instead of turning
  # them into all-NA rows.
  storm_data <- data[which(data[[target_var]] == 1), , drop = FALSE]
  non_storm_data <- data[which(data[[target_var]] == 0), , drop = FALSE]
  n_storm <- nrow(storm_data)
  n_non_storm <- nrow(non_storm_data)

  # Number of test rows per class (never more than the class holds)
  n_test_storm <- if (n_storm > 0) {
    min(n_storm, max(1, round(n_storm * test_size)))
  } else {
    0
  }
  n_test_non_storm <- min(n_non_storm, round(n_non_storm * test_size))

  # sample.int() is safe for empty classes, unlike sample(1:0, ...)
  test_storm_idx <- sample.int(n_storm, n_test_storm)
  test_non_storm_idx <- sample.int(n_non_storm, n_test_non_storm)

  # Training rows are selected positively: x[-integer(0), ] would drop every
  # row when a class contributes no test rows.
  train_storm_idx <- setdiff(seq_len(n_storm), test_storm_idx)
  train_non_storm_idx <- setdiff(seq_len(n_non_storm), test_non_storm_idx)

  test_data <- rbind(
    storm_data[test_storm_idx, , drop = FALSE],
    non_storm_data[test_non_storm_idx, , drop = FALSE]
  )
  train_data <- rbind(
    storm_data[train_storm_idx, , drop = FALSE],
    non_storm_data[train_non_storm_idx, , drop = FALSE]
  )

  # Shuffle the row order
  train_data <- train_data[sample.int(nrow(train_data)), , drop = FALSE]
  test_data <- test_data[sample.int(nrow(test_data)), , drop = FALSE]

  # Report the result
  cat(sprintf("Training set: %d samples (dust storms: %.2f%%)\n",
              nrow(train_data), mean(train_data[[target_var]]) * 100))
  cat(sprintf("Test set: %d samples (dust storms: %.2f%%)\n",
              nrow(test_data), mean(test_data[[target_var]]) * 100))

  list(
    train_data = train_data,
    test_data = test_data
  )
}

# 11.
# Classification evaluation metrics
#
# y_train, y_test:                   true 0/1 labels (1 = dust storm).
# train_pred_class, test_pred_class: predicted 0/1 labels.
# train_pred_prob, test_pred_prob:   predicted probabilities of class 1.
#
# Returns a list with accuracy, precision, recall, F1 and AUC for both sets
# (prefixes train_ / test_) plus the two confusion matrices (train_cm,
# test_cm). Precision, recall, F1 and AUC are NA for a set whose true labels
# contain only one class.
#
# The metrics are computed in base R with class 1 as the positive class.
# Bare calls to precision(), recall() and auc() resolve to different
# packages (caret, Metrics, pROC) depending on attach order, and
# MLmetrics::F1_Score() without `positive` scores the first level ("0").
calculate_classification_metrics <- function(y_train, y_test,
                                             train_pred_class, test_pred_class,
                                             train_pred_prob, test_pred_prob) {
  # Metrics of one data set; `set_label` is only used in the warning text.
  metrics_for_set <- function(y_true, pred_class, pred_prob, set_label) {
    result <- list(
      accuracy = sum(pred_class == y_true) / length(y_true),
      precision = NA,
      recall = NA,
      f1 = NA,
      auc = NA
    )

    # Precision, recall, F1 and AUC need both classes in the true labels
    if (length(unique(y_true)) != 2) {
      cat(sprintf(paste0("Warning: %s set has only one class, cannot ",
                         "calculate precision, recall, F1 and AUC\n"),
                  set_label))
      return(result)
    }

    is_pos <- y_true == 1
    true_pos <- sum(pred_class == 1 & is_pos)
    false_pos <- sum(pred_class == 1 & !is_pos)
    false_neg <- sum(pred_class != 1 & is_pos)

    result$precision <- true_pos / (true_pos + false_pos)
    result$recall <- true_pos / (true_pos + false_neg)
    result$f1 <- 2 * result$precision * result$recall /
      (result$precision + result$recall)

    # AUC through the rank-sum (Mann-Whitney) identity: the probability that
    # a positive sample is scored above a negative one, ties counting half.
    n_pos <- as.numeric(sum(is_pos))
    n_neg <- as.numeric(sum(!is_pos))
    pos_rank_sum <- sum(rank(pred_prob)[is_pos])
    result$auc <- (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

    result
  }

  train <- metrics_for_set(y_train, train_pred_class, train_pred_prob,
                           "Training")
  test <- metrics_for_set(y_test, test_pred_class, test_pred_prob, "Test")

  # Confusion matrices
  train_cm <- table(Actual = y_train, Predicted = train_pred_class)
  test_cm <- table(Actual = y_test, Predicted = test_pred_class)

  list(
    train_accuracy = train$accuracy,
    train_precision = train$precision,
    train_recall = train$recall,
    train_f1 = train$f1,
    train_auc = train$auc,
    test_accuracy = test$accuracy,
    test_precision = test$precision,
    test_recall = test$recall,
    test_f1 = test$f1,
    test_auc = test$auc,
    train_cm = train_cm,
    test_cm = test_cm
  )
}

# 12.
# Print the classification results
#
# metrics: list with the elements train_/test_ accuracy, precision, recall,
#          f1, auc and the confusion matrices train_cm / test_cm.
#
# A data set is reported as "N/A (only one class)" only when all four of
# precision, recall, F1 and AUC are missing. A single undefined metric (for
# example NaN precision when no positive was predicted) no longer hides the
# other values.
print_classification_results <- function(metrics) {
  metric_keys <- c("precision", "recall", "f1", "auc")
  metric_labels <- c("Precision", "Recall", "F1 Score", "AUC")

  # Print accuracy and the four class-dependent metrics of one data set.
  print_set <- function(prefix) {
    cat(sprintf(" Accuracy: %.4f\n", metrics[[paste0(prefix, "_accuracy")]]))
    scores <- vapply(metric_keys, function(key) {
      as.numeric(metrics[[paste0(prefix, "_", key)]])
    }, numeric(1))
    if (all(is.na(scores))) {
      cat(paste0(" ", metric_labels, ": N/A (only one class)\n"), sep = "")
    } else {
      cat(sprintf(" %s: %.4f\n", metric_labels, scores), sep = "")
    }
  }

  cat("\n=== Classification Model Evaluation Results ===\n")

  cat("Training Set Performance:\n")
  print_set("train")

  cat("\nTest Set Performance:\n")
  print_set("test")

  cat("\nTraining Set Confusion Matrix:\n")
  print(metrics$train_cm)

  cat("\nTest Set Confusion Matrix:\n")
  print(metrics$test_cm)
}

# 13. Feature-set definitions
#
# Returns a named list of character vectors with the column names used by
# each model variant:
#   immediate_only - unlagged variables only (baseline model)
#   lagged_only    - the 5-row lagged variables plus wind_hours_0 and vhi_0
#   hybrid_optimal - recommended mix, chosen from the lag analysis
define_feature_sets <- function() {
  list(
    # Immediate features only (baseline model)
    immediate_only = c("wind_hours", "vhi", "precip", "temp", "ndvi", "lai",
                       "lst", "sm", "et"),

    # Lagged features
    lagged_only = c("wind_hours_0", "vhi_0", "precip_lag5", "temp_lag5",
                    "ndvi_lag5", "lai_lag5", "lst_lag5", "sm_lag5", "et_lag5",
                    "vhi_lag5"),

    # Hybrid features (recommended) - based on the lag analysis
    hybrid_optimal = c(
      "wind_hours_0", "vhi_0",                # strongly related, immediate
      "temp_lag5", "et_lag5", "ndvi_lag5",    # strongly related at lag 5
      "lst_lag5", "lai_lag5", "precip_lag5",  # important factors at lag 5
      "sm_lag5"                               # moderately related at lag 5
    )
  )
}

# 14.
# Improved XGBoost classification analysis
#
# train_data, test_data: data frames holding the features and the target.
# target_var:            name of the 0/1 target column.
# feature_set_name:      name of an entry of define_feature_sets().
#
# Features missing from train_data or without variation in it are dropped.
# The remaining features are centred and scaled with parameters estimated on
# the training set, an XGBoost binary-logistic model is trained, and both
# sets are scored (class 1 when the probability exceeds 0.5).
# Returns a list with the model, the scaled matrices, labels, predictions,
# metrics, the feature names, the preprocessing parameters and the feature
# set name.
perform_xgboost_classification_improved <- function(
    train_data, test_data, target_var = "dust_storm",
    feature_set_name = "hybrid_optimal") {
  cat("Performing improved XGBoost classification analysis...\n")
  cat("Using feature set:", feature_set_name, "\n")

  set.seed(123)

  # Look up the feature set
  feature_sets <- define_feature_sets()
  if (!feature_set_name %in% names(feature_sets)) {
    stop("Unknown feature set name: ", feature_set_name)
  }
  features <- feature_sets[[feature_set_name]]

  # Keep features that exist and vary. isTRUE() also rejects an NA standard
  # deviation (all-NA column), which would otherwise make if() fail.
  is_usable <- vapply(features, function(feature) {
    feature %in% colnames(train_data) &&
      isTRUE(stats::sd(train_data[[feature]], na.rm = TRUE) > 0)
  }, logical(1))
  valid_features <- features[is_usable]

  if (length(valid_features) == 0) {
    stop("No valid features available for modeling")
  }

  cat(sprintf("Valid features: %d (original features: %d)\n",
              length(valid_features), length(features)))
  cat("Features used:", paste(valid_features, collapse = ", "), "\n")

  # Data matrices; drop = FALSE keeps a named one-column matrix when only a
  # single feature is left.
  X_train <- as.matrix(train_data[, valid_features, drop = FALSE])
  X_test <- as.matrix(test_data[, valid_features, drop = FALSE])
  y_train <- train_data[[target_var]]
  y_test <- test_data[[target_var]]

  # Standardisation (parameters estimated on the training set only)
  preprocess_params <- caret::preProcess(X_train,
                                         method = c("center", "scale"))
  X_train_scaled <- predict(preprocess_params, X_train)
  X_test_scaled <- predict(preprocess_params, X_test)

  # Train the XGBoost classifier
  # NOTE(review): no separate evaluation set is passed, so early stopping
  # presumably monitors the training log-loss - confirm this is intended.
  xgb_model <- xgboost::xgboost(
    data = X_train_scaled,
    label = y_train,
    max_depth = 3,
    eta = 0.05,
    nthread = 2,
    nrounds = 150,
    objective = "binary:logistic",
    eval_metric = "logloss",
    verbose = 0,
    early_stopping_rounds = 25,
    subsample = 0.7,
    colsample_bytree = 0.7,
    lambda = 2,
    alpha = 0.5,
    min_child_weight = 3
  )

  # Predicted probabilities
  train_pred_prob <- predict(xgb_model, X_train_scaled)
  test_pred_prob <- predict(xgb_model, X_test_scaled)

  # Class predictions (threshold 0.5)
  train_pred_class <- ifelse(train_pred_prob > 0.5, 1, 0)
  test_pred_class <- ifelse(test_pred_prob > 0.5, 1, 0)

  # Evaluation metrics
  metrics <- calculate_classification_metrics(
    y_train, y_test,
    train_pred_class, test_pred_class,
    train_pred_prob, test_pred_prob
  )
  print_classification_results(metrics)

  list(
    model = xgb_model,
    features = valid_features,
    X_train = X_train_scaled,
    X_test = X_test_scaled,
    y_train = y_train,
    y_test = y_test,
    train_pred_prob = train_pred_prob,
    test_pred_prob = test_pred_prob,
    train_pred_class = train_pred_class,
    test_pred_class = test_pred_class,
    metrics = metrics,
    feature_names = valid_features,
    preprocess_params = preprocess_params,
    feature_set_name = feature_set_name
  )
}

# 15. Compare the performance of the different feature sets
#
# Trains one model per entry of define_feature_sets() that has at least one
# column available in train_data.
# Returns list(results = <per feature set: features, metrics, model>,
#              comparison = <data frame with one row per trained set>).
compare_feature_sets <- function(train_data, test_data,
                                 target_var = "dust_storm") {
  cat("Comparing model performance across different feature sets...\n")

  feature_sets <- define_feature_sets()
  results <- list()

  for (set_name in names(feature_sets)) {
    cat(sprintf("\n=== Training feature set: %s ===\n", set_name))

    # Features of this set that exist in the training data
    features <- feature_sets[[set_name]]
    available_features <- features[features %in% colnames(train_data)]

    if (length(available_features) == 0) {
      cat(" No available features, skipping\n")
      next
    }

    xgb_result <- perform_xgboost_classification_improved(
      train_data, test_data, target_var, set_name
    )
    results[[set_name]] <- list(
      features = available_features,
      metrics = xgb_result$metrics,
      model = xgb_result$model
    )
  }

  # Comparison table: one row per trained set, bound in a single step
  cat("\n=== Feature Set Performance Comparison ===\n")
  comparison_rows <- lapply(names(results), function(set_name) {
    metrics <- results[[set_name]]$metrics
    data.frame(
      Feature_Set = set_name,
      Train_Accuracy = metrics$train_accuracy,
      Test_Accuracy = metrics$test_accuracy,
      Test_Precision = as.numeric(metrics$test_precision),
      Test_Recall = as.numeric(metrics$test_recall),
      Test_F1 = as.numeric(metrics$test_f1),
      Test_AUC = as.numeric(metrics$test_auc),
      Num_Features = length(results[[set_name]]$features)
    )
  })
  comparison_df <- if (length(comparison_rows) > 0) {
    do.call(rbind, comparison_rows)
  } else {
    data.frame()
  }

  print(comparison_df)

  list(
    results = results,
    comparison = comparison_df
  )
}

# 16.
# Journal-style high-resolution beeswarm plot (final optimised version)
#
# shap_matrix:     matrix of SHAP values (rows = samples, columns = features).
# X_train:         feature matrix with the same rows and column names.
# importance_data: data frame with a `feature` column ordered by importance;
#                  it defines the order of the features on the y axis.
# output_dir:      base directory; files are written to
#                  <output_dir>/HighRes_SHAP_Results.
#
# At most 800 rows are sampled (seed 123). The plot is saved as a 900 DPI PNG
# and as a PDF. Returns the ggplot object, or NULL when anything fails.
create_stable_beeswarm_plot <- function(shap_matrix, X_train, importance_data,
                                        output_dir) {
  cat("Creating stable high-resolution beeswarm plot using standard ggplot...\n")

  tryCatch({
    # 1. Sample rows
    max_samples <- min(800, nrow(shap_matrix))
    set.seed(123)
    sample_idx <- sample.int(nrow(shap_matrix), max_samples)

    shap_df <- as.data.frame(shap_matrix[sample_idx, , drop = FALSE])
    feature_df <- as.data.frame(X_train[sample_idx, , drop = FALSE])

    # 2. Features present in both inputs
    common_features <- intersect(colnames(shap_df), colnames(feature_df))
    if (length(common_features) == 0) {
      stop("SHAP矩阵与特征矩阵无共同特征")
    }

    # 3. Long format, built in one step instead of rbind() inside a loop
    plot_data <- data.frame(
      feature = rep(common_features, each = nrow(shap_df)),
      shap_value = unlist(shap_df[common_features], use.names = FALSE),
      feature_value = unlist(feature_df[common_features], use.names = FALSE),
      stringsAsFactors = FALSE
    )

    # 4. Factor levels follow the importance order (most important on top)
    valid_features <-
      importance_data$feature[importance_data$feature %in% common_features]
    if (length(valid_features) == 0) valid_features <- common_features
    plot_data$feature <- factor(plot_data$feature,
                                levels = rev(valid_features))

    # Remove unusable rows
    plot_data <- plot_data[!is.na(plot_data$shap_value) &
                             !is.na(plot_data$feature_value), ]
    if (nrow(plot_data) == 0) {
      stop("No valid data after filtering NA values")
    }

    # 5. Beeswarm-like plot. Points are jittered along the feature axis only,
    # so every point keeps its exact SHAP value on the x axis.
    p <- ggplot2::ggplot(
      plot_data,
      ggplot2::aes(x = shap_value, y = feature, color = feature_value)
    ) +
      ggplot2::geom_point(
        position = ggplot2::position_jitter(width = 0, height = 0.2,
                                            seed = 123),
        alpha = 0.8,
        size = 1.8
      ) +
      # Reference line at zero impact
      ggplot2::geom_vline(
        xintercept = 0,
        linetype = "dashed",
        color = "black",
        linewidth = 0.7
      ) +
      # Diverging colour scale centred on the median feature value
      ggplot2::scale_color_gradient2(
        low = "#1E88E5",
        mid = "#FFFFFF",
        high = "#E53935",
        midpoint = median(plot_data$feature_value, na.rm = TRUE),
        name = "Feature Value",
        guide = ggplot2::guide_colorbar(
          barwidth = 12,
          barheight = 0.8,
          title.position = "top",
          title.hjust = 0.5
        )
      ) +
      # Axis labels and titles
      ggplot2::labs(
        x = "SHAP Value (Impact on Model Output)",
        y = "Features",
        title = "SHAP Feature Importance Beeswarm Plot",
        subtitle = "Each point represents a sample's feature contribution"
      ) +
      # Theme
      ggplot2::theme_minimal(base_size = 14) +
      ggplot2::theme(
        plot.title = ggplot2::element_text(
          hjust = 0.5, face = "bold", size = 16, color = "black",
          margin = ggplot2::margin(b = 10)
        ),
        plot.subtitle = ggplot2::element_text(
          hjust = 0.5, size = 12, color = "black",
          margin = ggplot2::margin(b = 15)
        ),
        axis.title.x = ggplot2::element_text(
          face = "bold", size = 14, color = "black",
          margin = ggplot2::margin(t = 10)
        ),
        axis.title.y = ggplot2::element_text(
          face = "bold", size = 14, color = "black",
          margin = ggplot2::margin(r = 10)
        ),
        axis.text.x = ggplot2::element_text(
          size = 12, color = "black", face = "bold"
        ),
        axis.text.y = ggplot2::element_text(
          size = 12, color = "black", face = "bold"
        ),
        # Horizontal grid lines only
        panel.grid.major.x = ggplot2::element_blank(),
        panel.grid.minor.x = ggplot2::element_blank(),
        panel.grid.major.y = ggplot2::element_line(
          color = "grey90", linewidth = 0.4
        ),
        panel.grid.minor.y = ggplot2::element_blank(),
        axis.line.x = ggplot2::element_line(
          color = "black", linewidth = 0.6
        ),
        legend.position = "bottom",
        legend.title = ggplot2::element_text(face = "bold", size = 12),
        legend.text = ggplot2::element_text(size = 10),
        plot.background = ggplot2::element_rect(fill = "white", color = NA),
        panel.background = ggplot2::element_rect(fill = "white", color = NA),
        plot.margin = ggplot2::margin(20, 20, 20, 20)
      ) +
      # Small padding on the x axis
      ggplot2::scale_x_continuous(
        expand = ggplot2::expansion(mult = c(0.03, 0.03))
      )

    # 6. Save the high-resolution files
    shap_dir <- file.path(output_dir, "HighRes_SHAP_Results")
    dir.create(shap_dir, showWarnings = FALSE, recursive = TRUE)
    plot_height <- max(8, 0.6 * length(valid_features))

    png_file <- file.path(shap_dir, "SHAP_Beeswarm_Stable.png")
    ggplot2::ggsave(
      png_file,
      plot = p,
      width = 11,
      height = plot_height,
      dpi = 900,
      bg = "white"
    )

    pdf_file <- file.path(shap_dir, "SHAP_Beeswarm_Stable.pdf")
    ggplot2::ggsave(
      pdf_file,
      plot = p,
      width = 11,
      height = plot_height,
      device = grDevices::cairo_pdf,
      bg = "white"
    )

    cat("Stable beeswarm plot saved:", png_file, "and", pdf_file, "\n")
    p
  }, error = function(e) {
    cat("Stable beeswarm plot generation failed:", e$message, "\n")
    NULL
  })
}

# 17.
# Lagged-feature importance analysis (new)
#
# shap_analysis: list with element `importance_percentage`, a data frame with
#                the columns `feature` and `percentage_contribution`.
# xgb_results:   not used by this function.
# output_dir:    base directory; files are written to
#                <output_dir>/Lagged_Feature_Analysis.
#
# Writes two 900 DPI PNG figures and two CSV files for the features whose
# name contains "_lag". Returns list(lag_plot, comparison_plot, summary), or
# NULL when there is no lagged feature or when the analysis fails.
analyze_lagged_feature_importance <- function(shap_analysis, xgb_results,
                                              output_dir) {
  cat("Analyzing lagged feature importance patterns...\n")

  lag_dir <- file.path(output_dir, "Lagged_Feature_Analysis")
  dir.create(lag_dir, showWarnings = FALSE, recursive = TRUE)

  # Number following "_lag" in a feature name (0 when there is none). Reading
  # the digits avoids confusing e.g. "_lag12" with "_lag1".
  extract_lag <- function(feature_name) {
    digits <- regmatches(
      feature_name,
      regexpr("(?<=_lag)\\d+", feature_name, perl = TRUE)
    )
    if (length(digits) == 1) as.numeric(digits) else 0
  }

  # Fill colours are defined for lags 0, 3 and 5 only.
  lag_colours <- c("0" = "#1f77b4", "3" = "#ff7f0e", "5" = "#2ca02c")

  # Theme shared by both figures.
  lag_theme <- function() {
    theme_minimal(base_size = 12) +
      theme(
        plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
        plot.subtitle = element_text(hjust = 0.5, size = 10),
        axis.title = element_text(face = "bold", size = 12),
        axis.text = element_text(size = 10, color = "black"),
        legend.position = "bottom"
      )
  }

  tryCatch({
    importance_data <- shap_analysis$importance_percentage

    # Identify lagged and immediate features
    lag_features <- grep("_lag", importance_data$feature, value = TRUE)
    immediate_features <- grep("_0$|^[^_]*$", importance_data$feature,
                               value = TRUE)

    cat(sprintf("Found %d lagged features and %d immediate features\n",
                length(lag_features), length(immediate_features)))

    if (length(lag_features) == 0) {
      cat("No lagged features found for analysis\n")
      return(NULL)
    }

    lag_data <- importance_data[importance_data$feature %in% lag_features, ]
    lag_data$lag_months <- vapply(lag_data$feature, extract_lag, numeric(1),
                                  USE.NAMES = FALSE)
    lag_data$base_feature <- gsub("_lag\\d+", "", lag_data$feature)
    lag_data$base_feature <- gsub("_0$", "", lag_data$base_feature)

    # Importance of every lagged feature
    p_lag <- ggplot(lag_data,
                    aes(x = reorder(feature, percentage_contribution),
                        y = percentage_contribution,
                        fill = as.factor(lag_months))) +
      geom_col(alpha = 0.9) +
      scale_fill_manual(values = lag_colours, name = "Lag (months)") +
      coord_flip() +
      labs(
        title = "Lagged Feature Importance Analysis",
        subtitle = "Impact of time-lagged environmental factors on dust storm prediction",
        x = "Features with Lag Periods",
        y = "SHAP Contribution Percentage (%)"
      ) +
      lag_theme()

    ggsave(file.path(lag_dir, "Lagged_Feature_Importance.png"), p_lag,
           width = 10, height = 8, dpi = 900, bg = "white")

    # Best lag per base feature
    feature_summary <- lag_data %>%
      group_by(base_feature) %>%
      summarise(
        max_contribution = max(percentage_contribution),
        best_lag = lag_months[which.max(percentage_contribution)],
        num_lags = n()
      ) %>%
      arrange(desc(max_contribution))

    write.csv(feature_summary,
              file.path(lag_dir, "Lagged_Feature_Summary.csv"),
              row.names = FALSE)
    write.csv(lag_data,
              file.path(lag_dir, "Lagged_Feature_Detailed_Analysis.csv"),
              row.names = FALSE)

    # Start from NULL instead of testing exists(), which would also find an
    # unrelated global object of the same name.
    p_comparison <- NULL
    if (nrow(feature_summary) > 0) {
      p_comparison <- ggplot(feature_summary,
                             aes(x = reorder(base_feature, max_contribution),
                                 y = max_contribution,
                                 fill = as.factor(best_lag))) +
        geom_col(alpha = 0.9) +
        scale_fill_manual(values = lag_colours,
                          name = "Optimal Lag (months)") +
        coord_flip() +
        labs(
          title = "Optimal Lag Periods for Environmental Factors",
          subtitle = "Best performing lag period for each environmental variable",
          x = "Environmental Factors",
          y = "Maximum SHAP Contribution (%)"
        ) +
        lag_theme()

      ggsave(file.path(lag_dir, "Optimal_Lag_Periods.png"), p_comparison,
             width = 10, height = 8, dpi = 900, bg = "white")
    }

    cat("Lagged feature analysis completed successfully\n")

    list(
      lag_plot = p_lag,
      comparison_plot = p_comparison,
      summary = feature_summary
    )
  }, error = function(e) {
    cat("Lagged feature analysis failed:", e$message, "\n")
    NULL
  })
}

# 18.
# High-resolution SHAP analysis (fixed version)
#
# xgb_result: list returned by perform_xgboost_classification_improved();
#             the elements `model`, `X_train` and `features` are used.
# output_dir: base directory; files are written to
#             <output_dir>/HighRes_SHAP_Results.
#
# Computes the per-sample feature contributions of the model on X_train,
# derives each feature's share of the mean absolute contribution, draws the
# importance and beeswarm figures and writes the importance table to CSV.
# Returns list(importance_percentage, shap_matrix, beeswarm_plot), or NULL
# when anything fails.
create_highres_shap_analysis <- function(xgb_result, output_dir) {
  cat("Creating high-resolution SHAP analysis...\n")

  shap_dir <- file.path(output_dir, "HighRes_SHAP_Results")
  dir.create(shap_dir, showWarnings = FALSE, recursive = TRUE)

  tryCatch({
    # Per-sample feature contributions (exact, not approximated)
    shap_values <- predict(
      xgb_result$model,
      xgb_result$X_train,
      predcontrib = TRUE,
      approxcontrib = FALSE
    )

    # With exactly one extra column, that last column is treated as the
    # bias term and dropped; otherwise the matrix is used as returned.
    n_features <- length(xgb_result$features)
    if (ncol(shap_values) == n_features + 1) {
      shap_matrix <- shap_values[, seq_len(n_features), drop = FALSE]
    } else {
      shap_matrix <- shap_values
    }

    # Make sure the columns carry the feature names
    if (ncol(shap_matrix) == n_features) {
      colnames(shap_matrix) <- xgb_result$features
    }

    # Importance = share of the mean absolute contribution
    feature_importance_percentage <- data.frame(
      feature = colnames(shap_matrix),
      mean_abs_shap = colMeans(abs(shap_matrix)),
      stringsAsFactors = FALSE
    ) %>%
      mutate(
        percentage_contribution = mean_abs_shap / sum(mean_abs_shap) * 100,
        rank = rank(-percentage_contribution)
      ) %>%
      arrange(desc(percentage_contribution))

    # Importance figure
    create_highres_importance_plot(feature_importance_percentage, shap_dir)

    # Beeswarm figure (the callee appends the results folder to output_dir)
    beeswarm_plot <- create_stable_beeswarm_plot(
      shap_matrix,
      xgb_result$X_train,
      feature_importance_percentage,
      output_dir
    )

    # Save the importance table
    write.csv(feature_importance_percentage,
              file.path(shap_dir, "Feature_Importance_Percentage.csv"),
              row.names = FALSE)

    cat("高分辨率SHAP分析完成\n")

    list(
      importance_percentage = feature_importance_percentage,
      shap_matrix = shap_matrix,
      beeswarm_plot = beeswarm_plot
    )
  }, error = function(e) {
    cat("高分辨率SHAP分析错误:", e$message, "\n")
    NULL
  })
}

# 19.
# High-resolution feature importance plot function

#' Draw and save a horizontal bar chart of the 15 leading features.
#'
#' @param importance_data Data frame with `feature` and
#'   `percentage_contribution` columns; the first 15 rows are plotted.
#' @param output_dir Directory receiving `HighRes_Feature_Importance.png`
#'   and `HighRes_Feature_Importance.pdf`.
#' @return `NULL`, invisibly (the value of the final `cat()` call).
create_highres_importance_plot <- function(importance_data, output_dir) {
  cat("Creating high-resolution feature importance plot...\n")

  # Only the first 15 rows are shown
  top_features <- head(importance_data, 15)

  bar_theme <- theme_minimal(base_size = 14) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 16),
      plot.subtitle = element_text(hjust = 0.5, size = 12),
      axis.title = element_text(face = "bold", size = 14),
      axis.text = element_text(size = 12, color = "black"),
      axis.text.y = element_text(face = "bold"),
      plot.background = element_rect(fill = "white", color = NA),
      panel.grid.major = element_line(color = "grey90", linewidth = 0.2),
      panel.grid.minor = element_line(color = "grey95", linewidth = 0.1)
    )

  importance_plot <- ggplot(
    top_features,
    aes(
      x = percentage_contribution,
      y = reorder(feature, percentage_contribution)
    )
  ) +
    geom_col(fill = "#3498db", alpha = 0.9, width = 0.8) +
    geom_text(
      aes(label = sprintf("%.1f%%", percentage_contribution)),
      hjust = -0.1,
      size = 5,
      color = "#2c3e50",
      fontface = "bold"
    ) +
    # Extra room on the right so the percentage labels are not clipped
    scale_x_continuous(expand = expansion(mult = c(0, 0.1))) +
    labs(
      title = "Feature Contribution Analysis for Dust Storms",
      subtitle = "Feature importance ranking based on SHAP values",
      x = "Contribution Percentage (%)",
      y = "Environmental Drivers"
    ) +
    bar_theme

  # Raster version
  png_path <- file.path(output_dir, "HighRes_Feature_Importance.png")
  ggsave(png_path, importance_plot,
         width = 12, height = 8, dpi = 900, bg = "white")

  # Vector version
  pdf_path <- file.path(output_dir, "HighRes_Feature_Importance.pdf")
  ggsave(pdf_path, importance_plot,
         width = 12, height = 8, device = cairo_pdf, bg = "white")

  cat("高分辨率特征重要性图已保存\n")
}

# 20. 
# High-resolution ROC curve function

#' Plot training and test ROC curves and export them with an AUC table.
#'
#' @param xgb_result List with `y_train`, `y_test`, `train_pred_prob` and
#'   `test_pred_prob`.
#' @param output_dir Directory under which `HighRes_ROC_Curves` is created.
#' @return The ggplot object, or `NULL` when either response vector does not
#'   have exactly two distinct values or an error occurs.
create_highres_roc_curve <- function(xgb_result, output_dir) {
  cat("Creating high-resolution ROC curve...\n")

  roc_dir <- file.path(output_dir, "HighRes_ROC_Curves")
  dir.create(roc_dir, showWarnings = FALSE, recursive = TRUE)

  # Convert a pROC object into plot coordinates. Rows are ordered by
  # (fpr, tpr) so vertical steps are traced bottom-up by geom_path().
  roc_curve_points <- function(roc_obj, label) {
    pts <- data.frame(
      fpr = 1 - roc_obj$specificities,
      tpr = roc_obj$sensitivities,
      dataset = label
    )
    pts[order(pts$fpr, pts$tpr), , drop = FALSE]
  }

  tryCatch({
    # Both response vectors need exactly two distinct values
    has_two_classes <- length(unique(xgb_result$y_train)) == 2 &&
      length(unique(xgb_result$y_test)) == 2

    if (has_two_classes) {
      # Training-set ROC
      roc_train <- pROC::roc(
        response = xgb_result$y_train,
        predictor = xgb_result$train_pred_prob,
        levels = c(0, 1),
        direction = "<"
      )

      # Test-set ROC
      roc_test <- pROC::roc(
        response = xgb_result$y_test,
        predictor = xgb_result$test_pred_prob,
        levels = c(0, 1),
        direction = "<"
      )

      # Plain numerics: the classed "auc" objects carry extra attributes
      train_auc <- round(as.numeric(pROC::auc(roc_train)), 3)
      test_auc <- round(as.numeric(pROC::auc(roc_test)), 3)

      roc_data <- rbind(
        roc_curve_points(roc_train, "Training Set"),
        roc_curve_points(roc_test, "Test Set")
      )

      # Explicit breaks pin each label to its own curve. Without them the
      # default breaks sort alphabetically ("Test Set" first) and the
      # positional labels end up swapped.
      dataset_breaks <- c("Training Set", "Test Set")
      dataset_labels <- c(
        paste0("Training Set (AUC = ", train_auc, ")"),
        paste0("Test Set (AUC = ", test_auc, ")")
      )

      p_roc <- ggplot(roc_data, aes(x = fpr, y = tpr, color = dataset)) +
        geom_path(linewidth = 1.2) +
        geom_abline(slope = 1, intercept = 0, linetype = "dashed",
                    color = "gray50", linewidth = 0.8) +
        scale_color_manual(
          values = c("Training Set" = "#377EB8", "Test Set" = "#E41A1C"),
          breaks = dataset_breaks,
          labels = dataset_labels
        ) +
        coord_equal() +
        theme_bw(base_size = 14) +
        labs(
          x = "False Positive Rate (1 - Specificity)",
          y = "True Positive Rate (Sensitivity)",
          color = "Dataset"
        ) +
        theme(
          axis.title = element_text(face = "bold", size = 14),
          axis.text = element_text(size = 12, color = "black"),
          legend.title = element_text(face = "bold", size = 12),
          legend.text = element_text(size = 11),
          legend.position = c(0.7, 0.3),
          legend.background = element_rect(fill = "white", color = "black"),
          panel.grid.major = element_line(color = "grey90", linewidth = 0.3),
          panel.grid.minor = element_line(color = "grey95", linewidth = 0.1),
          panel.border = element_rect(color = "black", fill = NA,
                                      linewidth = 0.8)
        ) +
        scale_x_continuous(expand = expansion(mult = c(0, 0.02))) +
        scale_y_continuous(expand = expansion(mult = c(0, 0.02)))

      # Raster version
      ggsave(file.path(roc_dir, "HighRes_ROC_Curve.png"), p_roc,
             width = 8, height = 7, dpi = 900, bg = "white")

      # PDF version for manuscripts
      ggsave(file.path(roc_dir, "HighRes_ROC_Curve.pdf"), p_roc,
             width = 8, height = 7, device = cairo_pdf, bg = "white")

      # AUC comparison table
      auc_comparison <- data.frame(
        Dataset = c("Training Set", "Test Set"),
        AUC = c(train_auc, test_auc),
        Samples = c(length(xgb_result$y_train), length(xgb_result$y_test))
      )
      write.csv(auc_comparison, file.path(roc_dir, "AUC_Comparison.csv"),
                row.names = FALSE)

      cat("高分辨率ROC曲线已保存\n")
      cat(sprintf("训练集AUC: %.3f\n", train_auc))
      cat(sprintf("测试集AUC: %.3f\n", test_auc))

      return(p_roc)
    } else {
      cat("响应变量只有一个水平,跳过ROC曲线绘制\n")
      return(NULL)
    }
  }, error = function(e) {
    cat("创建高分辨率ROC曲线时出错:", e$message, "\n")
    return(NULL)
  })
}

# 21. High-resolution model diagnostics function

#' Create the diagnostic plots for the classification model.
#'
#' @param xgb_result Model result list, forwarded to
#'   `create_highres_roc_curve()`.
#' @param output_dir Directory under which `HighRes_Model_Diagnostics` is
#'   created.
#' @return A list of plots; contains `roc_plot` when the ROC curve could be
#'   drawn and is empty otherwise.
create_highres_classification_diagnostics <- function(xgb_result, output_dir) {
  cat("Creating high-resolution classification model diagnostics...\n")

  diagnostics_dir <- file.path(output_dir, "HighRes_Model_Diagnostics")
  dir.create(diagnostics_dir, showWarnings = FALSE, recursive = TRUE)

  plots <- list()

  # 1. High-resolution ROC curve
  roc_plot <- create_highres_roc_curve(xgb_result, diagnostics_dir)
  if (!is.null(roc_plot)) {
    plots$roc_plot <- roc_plot
  }

  return(plots)
}

# 22. Package check function

#' Install any missing packages, then attach all required packages.
#'
#' @return `NULL`, invisibly; called for its side effects.
check_and_install_packages <- function() {
  cat("Checking and installing required packages...\n")

  # unique() drops the duplicated 'pROC' entry
  required_packages <- unique(c(
    'xgboost', 'caret', 'Metrics', 'tibble', 'dplyr', 'ggplot2', 'pROC',
    'viridis', 'ggpubr', 'ggbeeswarm', 'terra', 'tidyverse', 'lubridate',
    'patchwork', 'rsample', 'pROC', 'MLmetrics', 'randomForest', 'sf'
  ))

  for (pkg in required_packages) {
    already_installed <- requireNamespace(pkg, quietly = TRUE)
    if (!already_installed) {
      install.packages(pkg, dependencies = TRUE)
    }

    # library() stops with an error if the package is still unavailable,
    # whereas require() would only return FALSE.
    library(pkg, character.only = TRUE, quietly = TRUE)

    if (already_installed) {
      cat("Package already installed:", pkg, "\n")
    } else {
      cat("Installed and loaded:", pkg, "\n")
    }
  }
}

# 23. 
# Comprehensive report generation function

#' Write a short plain-text summary of the analysis run.
#'
#' @param xgb_results List whose `metrics` element supplies
#'   `train_accuracy`, `test_accuracy` and `test_auc`.
#' @param shap_results SHAP analysis result (not used by the report body).
#' @param data_info List with `data_type`, `sample_count` and
#'   `feature_count`.
#' @param output_dir Directory under which `Comprehensive_Report` is created.
#' @return `NULL`, invisibly; called for its side effects.
generate_comprehensive_report <- function(xgb_results, shap_results,
                                          data_info, output_dir) {
  cat("Generating comprehensive analysis report...\n")

  report_dir <- file.path(output_dir, "Comprehensive_Report")
  dir.create(report_dir, showWarnings = FALSE, recursive = TRUE)

  # Scalar-safe formatter: a missing (NULL) or NA metric is shown as "N/A"
  # instead of producing an empty value or an error from round().
  format_metric <- function(value) {
    if (is.null(value) || length(value) != 1 || is.na(value)) {
      "N/A"
    } else {
      round(value, 4)
    }
  }

  metrics <- xgb_results$metrics

  # Simple text report
  report_content <- paste(
    "Dust Storm Probability XGBoost+SHAP Analysis Report",
    paste("Generated time:", Sys.time()),
    paste("Data processing level:", data_info$data_type),
    paste("Total samples:", data_info$sample_count),
    paste("Number of features:", data_info$feature_count),
    "",
    "Model Performance:",
    paste("Training Accuracy:", format_metric(metrics$train_accuracy)),
    paste("Test Accuracy:", format_metric(metrics$test_accuracy)),
    paste("Test AUC:", format_metric(metrics$test_auc)),
    "",
    "Analysis completed successfully with high-resolution outputs.",
    sep = "\n"
  )

  writeLines(report_content, file.path(report_dir, "Analysis_Report.txt"))
  cat("Comprehensive report generated:", report_dir, "\n")
}

# 24. Main analysis function - fixed version

#' Run the complete dust-storm XGBoost + SHAP pipeline.
#'
#' @param output_base Root directory for all outputs.
#' @param processing_level `"pixel"` for pixel-level preprocessing; any other
#'   value selects regional preprocessing.
#' @param max_pixels Passed as `max_pixels_per_month` to the pixel-level
#'   preprocessing step.
#' @param include_lag_features Whether lagged features are created and
#'   analysed.
#' @param feature_set Feature set name; `"compare_all"` compares all sets and
#'   continues with the `hybrid_optimal` result.
#' @return A list with `xgb_results`, `shap_analysis`, `diagnostics`,
#'   `analysis_data`, `split_data`, `processing_level` and
#'   `include_lag_features`, or `NULL` when any step fails.
main_dust_storm_analysis_highres <- function(
    output_base = "I:/DustStorm_Analysis/XGBoost_SHAP_HighRes",
    processing_level = "pixel",
    max_pixels = 100,
    include_lag_features = TRUE,
    feature_set = "hybrid_optimal") {
  # Check and install the required packages
  check_and_install_packages()

  cat("Starting improved dust storm probability XGBoost analysis (high-resolution version)...\n")
  cat("Data processing level:", processing_level, "\n")
  cat("Output directory:", output_base, "\n")
  cat("Include lagged features:", include_lag_features, "\n")
  cat("Feature set used:", feature_set, "\n")

  # Create the output directory
  dir.create(output_base, showWarnings = FALSE, recursive = TRUE)

  # Point terra at a temporary directory below the output directory
  temp_dir <- file.path(output_base, "temp_terra")
  dir.create(temp_dir, showWarnings = FALSE)
  terraOptions(tempdir = temp_dir)
  cat("terra temporary directory set to:", temp_dir, "\n")

  # Identifier columns that are not model features
  id_columns <- c("year", "month", "date", "pixel_id", "x", "y")

  tryCatch({
    # 1. Read the data
    cat("=== Step 1: Reading data ===\n")
    raw_data <- read_all_dust_data_annual()

    if (is.null(raw_data) || length(raw_data$data) == 0) {
      stop("Failed to read any data, please check file paths")
    }

    # 2. Resample to a common resolution
    cat("=== Step 2: Resampling to common resolution ===\n")
    processed_data <- resample_to_common_resolution(raw_data$data)

    # 3. Preprocess at the selected level
    cat("=== Step 3: Preprocessing data ===\n")
    if (processing_level == "pixel") {
      analysis_data <- preprocess_pixel_level_data(
        processed_data,
        raw_data$time_index,
        max_pixels_per_month = max_pixels
      )
      data_type <- "pixel_level"
    } else {
      analysis_data <- preprocess_regional_data(processed_data,
                                                raw_data$time_index)
      data_type <- "regional_level"
    }

    if (is.null(analysis_data) || nrow(analysis_data) == 0) {
      stop("Data preprocessing failed, cannot continue analysis")
    }

    cat("Original sample size:", nrow(analysis_data), "\n")
    cat("Data years:",
        paste(sort(unique(analysis_data$year)), collapse = ", "), "\n")

    # 4. Build lagged features
    if (include_lag_features) {
      cat("=== Step 4: Creating lagged features ===\n")
      analysis_data <- create_lagged_features(analysis_data)
      cat("Lagged features created successfully\n")
      cat("Available features:",
          paste(setdiff(colnames(analysis_data), id_columns),
                collapse = ", "),
          "\n")
    }

    # 5. Create the dust storm class labels
    cat("=== Step 5: Creating dust storm labels ===\n")
    analysis_data <- create_dust_storm_labels(
      analysis_data,
      dcmd_threshold_method = "quantile",
      threshold_value = 0.75
    )

    # 6. Split into training and test sets
    cat("=== Step 6: Splitting data into training and test sets ===\n")
    split_data <- split_train_test_data(analysis_data, test_size = 0.3)

    # Check the split result
    if (is.null(split_data$train_data) || is.null(split_data$test_data)) {
      stop("Data splitting failed")
    }

    cat(sprintf("Training set samples: %d\n", nrow(split_data$train_data)))
    cat(sprintf("Test set samples: %d\n", nrow(split_data$test_data)))
    cat(sprintf("Training set dust storm ratio: %.2f%%\n",
                mean(split_data$train_data$dust_storm) * 100))
    cat(sprintf("Test set dust storm ratio: %.2f%%\n",
                mean(split_data$test_data$dust_storm) * 100))

    # 7. XGBoost classification
    cat("=== Step 7: Performing XGBoost classification ===\n")

    if (feature_set == "compare_all") {
      # Compare all feature sets and continue with hybrid_optimal
      comparison_results <- compare_feature_sets(split_data$train_data,
                                                 split_data$test_data)
      xgb_results <- comparison_results$results$hybrid_optimal
    } else {
      # Use the requested feature set
      xgb_results <- perform_xgboost_classification_improved(
        split_data$train_data,
        split_data$test_data,
        feature_set_name = feature_set
      )
    }

    # Every later step reads from xgb_results, so stop here with a clear
    # message instead of letting a downstream step fail obscurely.
    if (is.null(xgb_results)) {
      stop("XGBoost classification returned no results")
    }

    # 8. High-resolution SHAP analysis
    cat("=== Step 8: Performing SHAP analysis ===\n")
    shap_analysis <- create_highres_shap_analysis(xgb_results, output_base)

    # 9. Lagged feature analysis
    if (include_lag_features && !is.null(shap_analysis)) {
      cat("=== Step 9: Analyzing lagged feature importance ===\n")
      analyze_lagged_feature_importance(shap_analysis, xgb_results,
                                        output_base)
    }

    # 10. High-resolution model diagnostics
    cat("=== Step 10: Creating model diagnostics ===\n")
    diagnostic_plots <- create_highres_classification_diagnostics(
      xgb_results,
      output_base
    )

    # 11. Comprehensive report
    cat("=== Step 11: Generating comprehensive report ===\n")
    data_info <- list(
      sample_count = nrow(analysis_data),
      feature_count = length(setdiff(colnames(analysis_data),
                                     c(id_columns, "dust_storm"))),
      data_type = data_type
    )

    generate_comprehensive_report(xgb_results, shap_analysis, data_info,
                                  output_base)

    # Final summary
    separator_line <- paste(rep("=", 80), collapse = "")
    cat("\n", separator_line, "\n", sep = "")
    cat("=== Improved Dust Storm Probability XGBoost+SHAP Analysis Completed (High-Resolution) ===\n")
    cat(separator_line, "\n")
    cat("Output directory:", output_base, "\n")
    cat("Data processing level:", processing_level, "\n")
    cat("Total sample size:", nrow(analysis_data), "\n")
    cat("Training set samples:", nrow(split_data$train_data), "\n")
    cat("Test set samples:", nrow(split_data$test_data), "\n")
    cat("Lagged features included:", include_lag_features, "\n")
    cat("Feature set used:", feature_set, "\n")
    cat("Main improvements:\n")
    cat("- Complete data processing pipeline added\n")
    cat("- High-resolution output (900 DPI) with English labels\n")
    cat("- Journal-style formatting for all figures\n")
    cat("- All known errors fixed\n")

    return(list(
      xgb_results = xgb_results,
      shap_analysis = shap_analysis,
      diagnostics = diagnostic_plots,
      analysis_data = analysis_data,
      split_data = split_data,
      processing_level = processing_level,
      include_lag_features = include_lag_features
    ))
  }, error = function(e) {
    cat("\n=== Analysis Failed ===\n")
    cat("Error message:", e$message, "\n")
    # traceback() only reports the last *uncaught* error, so inside a
    # handler it would print nothing or a stale stack. Report the call
    # recorded on the condition instead.
    failed_call <- conditionCall(e)
    if (!is.null(failed_call)) {
      cat("Failed call:", deparse(failed_call)[1], "\n")
    }
    return(NULL)
  })
}

# Run the analysis
cat("High-resolution dust storm XGBoost+SHAP analysis script loaded\n")
cat("Run the following command to start analysis:\n")
cat("results_highres <- main_dust_storm_analysis_highres(processing_level = 'pixel', max_pixels = 100, include_lag_features = TRUE)\n")

# Start the analysis
cat("\nStarting high-resolution analysis...\n")
results_highres <- main_dust_storm_analysis_highres(
  output_base = "I:/DustStorm_Analysis/XGBoost_SHAP_HighRes",
  processing_level = "pixel",
  max_pixels = 100,
  include_lag_features = TRUE,
  feature_set = "hybrid_optimal"
)

# Reported problem (from the original post): running the script fails and the
# author asks for the complete corrected code. Console output:
#   === Step 8: Performing SHAP analysis ===
#   Creating high-resolution SHAP analysis...
#   Creating high-resolution feature importance plot...
#   (high-resolution feature importance plot saved)
#   Creating stable high-resolution beeswarm plot using standard ggplot...
#   Stable beeswarm plot generation failed: argument "observed" is missing,
#   with no default
最新发布
11-15
# Random forest on the LASSO-screened COPD data set.

# Install a package only when it is missing, then attach it. This replaces
# the unconditional install.packages() calls, which reinstalled everything on
# each run, and also covers MLmetrics, which was attached but never installed.
required_packages <- c("randomForest", "caret", "pROC", "lava", "MLmetrics")
for (pkg in required_packages) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
  library(pkg, character.only = TRUE)
}

completed_copd <- read.csv("C:\\Users\\29930\\Desktop\\COPD2.csv")
completed_copd$COPD <- as.factor(completed_copd$COPD)

# 70/30 stratified split
set.seed(40705)
trainlist <- createDataPartition(completed_copd$COPD, p = 0.7, list = FALSE)
trainset <- completed_copd[trainlist, ]
testset <- completed_copd[-trainlist, ]

set.seed(40705)
rf.train <- randomForest(as.factor(COPD) ~ ., data = trainset,
                         importance = TRUE)
print(rf.train)

# 10-fold cross-validation
cv <- trainControl(method = "cv", number = 10, classProbs = TRUE,
                   summaryFunction = twoClassSummary)

# classProbs = TRUE needs class levels that are valid R names; levels such as
# "0"/"1" make train() fail. make.names() leaves valid levels untouched, and
# only this copy is relabelled so the rest of the script keeps the originals.
cv_trainset <- trainset
levels(cv_trainset$COPD) <- make.names(levels(cv_trainset$COPD))

# twoClassSummary reports ROC/Sens/Spec, so the tuning metric must be "ROC"
results <- train(COPD ~ ., data = cv_trainset, method = "rf",
                 trControl = cv, metric = "ROC")

# Cross-validation results
print(results)

plot(rf.train, main = "图1 lasso筛选变量数据集的随机森林与误差关系图")

# Test-set class predictions and accuracy
predictions <- predict(rf.train, testset, type = "class")
print(predictions)
confMatrix <- table(testset$COPD, predictions)
print(confMatrix)
acc <- sum(predictions == testset$COPD) / nrow(testset)
print(paste("Accuracy", acc))

rf.test <- predict(rf.train, newdata = testset, type = "class")
rf.cf <- caret::confusionMatrix(as.factor(rf.test), as.factor(testset$COPD))
print(rf.cf)

# Class probabilities and ROC
rf.test2 <- predict(rf.train, newdata = testset, type = "prob")
print(head(rf.test2))
ROC.rf <- multiclass.roc(testset$COPD, rf.test2, plot = TRUE,
                         print.auc = TRUE, legacy.axes = TRUE)
print(head(ROC.rf))

# Feature weights
varImpPlot(rf.train)
importance <- importance(rf.train)
# With importance = TRUE the first matrix column is the importance for the
# first class only and may be negative, which breaks the normalisation below.
# MeanDecreaseGini is non-negative, so the weights form a proper share of 1.
imp_df <- data.frame(feature = row.names(importance),
                     importance = importance[, "MeanDecreaseGini"])
imp_df$weight <- imp_df$importance / sum(imp_df$importance)
imp_df$score <- imp_df$weight * 100
print(imp_df)

# Follow-up request from the original post: add a scoring system to predict
# COPD, together with visualisations and an explicit scoring formula.
03-30
# Load required packages
library(poLCA)
library(dplyr)
library(tidyr)
library(ggplot2)

# Input path and indicator variables
file_path <- "D:/SHUJU/car_and_ebike.csv"
vars <- c("Driver.gender", "Driver.identity", "Passenger.car.state",
          "Weekend", "Road.condition.classification", "Crash.type",
          "Weather", "Visibility", "Lighting.condition",
          "Road.functional.class", "Rider.age",
          "Physical.separation.of.the.road", "Rider.gender",
          "Rider.hurt.part")

# Read and preprocess the data
data <- read.csv(file_path)
data <- data[vars]

# Recode every column as category codes starting at 1
data[] <- lapply(data, function(x) as.numeric(as.factor(x)))

# Build the LCA formula
f <- as.formula(paste("cbind(", paste(vars, collapse = ","), ") ~ 1"))

# Containers for the fitted models and their fit statistics
class_counts <- 1:10
models <- list()
stat_rows <- vector("list", length(class_counts))
N <- nrow(data)  # sample size, used for CAIC

# Fit models with 1 to 10 classes
for (k in class_counts) {
  cat("拟合 LCA 模型,类别数 =", k, "\n")
  set.seed(123)
  lca_model <- poLCA(f, data, nclass = k, na.rm = FALSE, verbose = FALSE)
  models[[k]] <- lca_model

  # Entropy R². It is undefined for a single class because the maximum
  # entropy log(1) is 0, so store NA instead of dividing by zero.
  posterior <- lca_model$posterior
  entropy <- -rowSums(posterior * log(posterior + 1e-10))
  max_entropy <- log(ncol(posterior))
  entropy_r2 <- if (max_entropy > 0) {
    1 - mean(entropy) / max_entropy
  } else {
    NA_real_
  }

  # CAIC
  ll <- lca_model$llik
  num_params <- lca_model$npar
  caic <- -2 * ll + num_params * (log(N) + 1)

  # Collect one row per model; rows are bound once after the loop instead of
  # growing a data frame with rbind() on every iteration.
  stat_rows[[k]] <- data.frame(
    K = k,
    BIC = lca_model$bic,
    AIC = lca_model$aic,
    CAIC = caic,
    EntropyR2 = entropy_r2
  )
}

fit_stats <- bind_rows(stat_rows)

# Long format for the information-criterion plot. `fit_plot` was used below
# without ever being defined.
fit_plot <- fit_stats %>%
  pivot_longer(
    cols = c(AIC, BIC, CAIC),
    names_to = "Metric",
    values_to = "Value"
  )

# ---------- Figure 1: AIC, BIC, CAIC ----------
# Set the y-axis range manually so the curves look flatter
y_min <- min(fit_plot$Value) * 0.98  # leave a little headroom
y_max <- max(fit_plot$Value) * 1.02

p1 <- ggplot(fit_plot, aes(x = K, y = Value, color = Metric)) +
  geom_line(size = 1.2) +
  geom_point(size = 2.5) +
  scale_x_continuous(breaks = 1:10) +
  labs(x = "Number of clusters", y = "Information criterion") +
  coord_cartesian(ylim = c(y_min, y_max)) +  # restrict the visible y range
  theme_minimal(base_size = 14) +
  theme(
    plot.title = element_blank(),
    legend.position = c(0.82, 0.85),
    legend.background = element_rect(fill = alpha("white", 0.6), color = NA),
    legend.title = element_blank()
  )

# ---------- Figure 2: entropy R² ----------
fit_stats_r2 <- fit_stats %>%
  filter(K > 1)

p2 <- ggplot(fit_stats_r2, aes(x = K, y = EntropyR2)) +
  geom_line(color = "#1f77b4", size = 1.2) +
  geom_point(color = "#1f77b4", size = 2.5) +
  scale_x_continuous(breaks = 2:10) +
  ylim(0, 1) +
  labs(x = "Number of clusters", y = "Entropy R²") +
  theme_minimal(base_size = 14) +
  theme(
    plot.title = element_blank()
  )

# Adjust the plot size option
options(repr.plot.width = 10, repr.plot.height = 5)

# Show the figures
print(p1)
print(p2)

# Follow-up request from the original post: this is my code; the best k has
# already been chosen and I would like to continue with the clustering step.
# Can you help me complete the code?
06-13
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值