这里的代码# 4. 随机森林建模
# 初始化基础模型
rf = RandomForestClassifier(random_state=42, n_jobs=-1)
# 5. 超参数优化
param_grid = {
'n_estimators': [100, 200],
'max_depth': [10, 20, None],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2]
}
grid_search = GridSearchCV(
estimator=rf,
param_grid=param_grid,
cv=3,
scoring='f1_weighted',
verbose=1
)
grid_search.fit(X_train, y_train)
# 获取最佳模型
best_rf = grid_search.best_estimator_
# 6. 模型评估
# 预测测试集
y_pred = best_rf.predict(X_test)
# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()出现如下错误:---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-6bfca005f6db> in <module>
18 verbose=1
19 )
---> 20 grid_search.fit(X_train, y_train)
21
22 # 获取最佳模型
/ddhome/.pyenv/versions/archiconda3/lib/python3.7/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
737 refit_start_time = time.time()
738 if y is not None:
--> 739 self.best_estimator_.fit(X, y, **fit_params)
740 else:
741 self.best_estimator_.fit(X, **fit_params)
/ddhome/.pyenv/versions/archiconda3/lib/python3.7/site-packages/sklearn/ensemble/_forest.py in fit(self, X, y, sample_weight)
293 """
294 # Validate or convert input data
--> 295 X = check_array(X, accept_sparse="csc", dtype=DTYPE)
296 y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
297 if sample_weight is not None:
/ddhome/.pyenv/versions/archiconda3/lib/python3.7/site-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
529 array = array.astype(dtype, casting="unsafe", copy=False)
530 else:
--> 531 array = np.asarray(array, order=order, dtype=dtype)
532 except ComplexWarning:
533 raise ValueError("Complex data not supported\n"
/ddhome/.pyenv/versions/archiconda3/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
83
84 """
---> 85 return array(a, dtype, copy=False, order=order)
86
87
ValueError: could not convert string to float: '10+ years'