GridSearchCV是参数优化工具,优化周期较长:
from sklearn.model_selection import GridSearchCV
import time
from xgboost import XGBRegressor
from sklearn.metrics import mean_squared_error, r2_score
# 简化网格搜索参数
param_grid = {
'max_depth': [3, 5], # 减少深度选项
'learning_rate': [0.05, 0.1], # 减少学习率选项
'n_estimators': [500] # 固定树的数量
}
# 设置超时和更少的交叉验证折数
grid = GridSearchCV(
XGBRegressor(verbosity=0), # 减少输出
param_grid,
cv=3, # 保持3折交叉验证
scoring='neg_mean_squared_error',
n_jobs=2, # 限制并行作业数
verbose=1 # 显示进度
)
# 添加超时保护
try:
start_time = time.time()
# 设置最大运行时间为5分钟
max_time = 300 # 秒
grid.fit(X_train_final, y_train)
best_model = grid.best_estimator_
print(f"最佳参数: {grid.best_params_}")
print(f"最佳得分: {-grid.best_score_:.2f}")
# 使用最佳模型进行预测
best_pred = best_model.predict(X_test_final)
print(f"最佳模型 MSE: {mean_squared_error(y_test, best_pred):.2f}")
print(f"最佳模型 R²: {r2_score(y_test, best_pred):.2f}")
except KeyboardInterrupt:
print("网格搜索被手动中断,使用已训练的基本模型")
best_model = model # 使用之前训练的基本模型
except Exception as e:
print(f"发生错误: {e}")
best_model = model # 使用之前训练的基本模型
优化前:
优化后:
代码解读:
from sklearn.model_selection import GridSearchCV
import time
-
导入
GridSearchCV
模块,用于超参数优化。 -
导入
time
模块,用于记录运行时间。
# 简化网格搜索参数
param_grid = {
'max_depth': [3, 5], # 减少深度选项
'learning_rate': [0.05, 0.1], # 减少学习率选项
'n_estimators': [500] # 固定树的数量
}
- 定义超参数网格
param_grid
,用于限制网格搜索的范围。-
max_depth
: 树的最大深度,取值为[3, 5]
。 -
learning_rate
: 学习率,取值为[0.05, 0.1]
。 -
n_estimators
: 树的数量,固定为500
。
-
# 设置超时和更少的交叉验证折数
grid = GridSearchCV(
XGBRegressor(verbosity=0), # 减少输出
param_grid,
cv=3, # 保持3折交叉验证
scoring='neg_mean_squared_error',
n_jobs=2, # 限制并行作业数
verbose=1 # 显示进度
)
- 初始化
GridSearchCV
对象,用于超参数优化。-
使用
XGBRegressor
作为基础模型,设置verbosity=0
以减少输出。 -
使用
param_grid
作为超参数网格。 -
设置
cv=3
,表示使用 3 折交叉验证。 -
使用
scoring='neg_mean_squared_error'
作为评分标准。 -
设置
n_jobs=2
,允许并行运行 2 个作业。 -
设置
verbose=1
,显示网格搜索的进度。
-
# 添加超时保护
try:
start_time = time.time()
# 设置最大运行时间为5分钟
max_time = 300 # 秒
grid.fit(X_train_final, y_train)
best_model = grid.best_estimator_
print(f"最佳参数: {grid.best_params_}")
print(f"最佳得分: {-grid.best_score_:.2f}")
# 使用最佳模型进行预测
best_pred = best_model.predict(X_test_final)
print(f"最佳模型 MSE: {mean_squared_error(y_test, best_pred):.2f}")
print(f"最佳模型 R²: {r2_score(y_test, best_pred):.2f}")
-
使用
try
块来捕获异常。 -
记录开始时间
start_time
。 -
设置最大运行时间为 5 分钟(300 秒)。
-
调用
grid.fit()
方法,开始网格搜索。 -
获取最佳模型
best_model
。 -
打印最佳参数和最佳得分。
-
使用最佳模型对测试集进行预测,并计算均方误差(MSE)和决定系数(R²)。
except KeyboardInterrupt:
print("网格搜索被手动中断,使用已训练的基本模型")
best_model = model # 使用之前训练的基本模型
except Exception as e:
print(f"发生错误: {e}")
best_model = model # 使用之前训练的基本模型
-
捕获
KeyboardInterrupt
异常,表示用户手动中断了网格搜索。 -
捕获其他异常,并打印错误信息。
-
在异常情况下,使用之前训练的基本模型
model
。
总结
这段代码的主要目的是使用网格搜索优化 XGBRegressor
的超参数,同时设置超时保护机制以防止长时间运行。代码逻辑清晰,注释详细,适合用于实际项目中的超参数优化。