最近开发公司的机器学习平台的XGBoost控件。结果报了一个bug,说“feature_names mismatch”。
现在我们来复现这个bug:
import xgboost as xgb
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X, y = make_classification(n_features=4)
X_train, X_test, y_train, y_test = train_test_split(X, y)
features = ['a1', 'a2', 'a3', 'a4']
df = pd.DataFrame(data=X_train, columns=features)
df['y'] = y_train
X_train = df[['a1', 'a2', 'a3', 'a4']]
y_train = df['y']
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
以上代码随机生成了一个分类数据集,它的特征的名字是a1, a2, a3, a4。我们用pandas的数据集,而不是numpy的数列,传给XGB的fit函数来训练模型。这里还没有出现bug。
然而在预测的时候,出现了bug。
model.predict(X_test)
bug的描述如下:
-----------------

在构建机器学习平台时遇到了XGBoost的'feature_names mismatch'错误。通过在预测时添加validate_features=False参数可以暂时解决,但这可能引发更多问题。为保持系统一致性,采用适配器设计模式,创建一个adapter类包装XGBClassifier,修改其predict方法以在不指定validate_features时也能正常工作。这种方式避免了直接修改XGB源码的复杂性,确保平台中所有组件遵循相同接口规则。
最低0.47元/天 解锁文章
4855





