This error had me stuck for a whole hour, and I couldn't find an answer anywhere online. Here's my source code:
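import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression

data = pd.read_csv(r'F:data\titanic\train.csv', na_values='?')
y = data['Survived']
# Drop the target and some columns we don't use as features
x = data.drop(columns=['Survived','PassengerId', 'Pclass', 'Name', 'Ticket'])
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=42)
col_cat = ['Sex', 'Embarked']
col_num = ['Age', 'SibSp', 'Parch', 'Fare']
# Categorical columns: fill missing values with a constant, then one-hot encode
pipe_cat = make_pipeline(SimpleImputer(strategy='constant'), OneHotEncoder(handle_unknown='ignore'))
# Numeric columns: fill missing values with the column mean, then standardize
pipe_num = make_pipeline(SimpleImputer(), StandardScaler())
# Columns first, pipeline second: this is the line that triggers the error below
preprocessor = make_column_transformer((col_cat, pipe_cat), (col_num, pipe_num))
pipe = make_pipeline(preprocessor, LogisticRegression(solver='lbfgs'))
pipe.fit(x_train, y_train)
accuracy = pipe.score(x_test, y_test)
print('Accuracy score of the {} is {:.2f}'.format(pipe.__class__.__name__, accuracy))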
When I ran it, it insisted on throwing this error:
Traceback (most recent call last):
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\sklearn\pipeline.py", line 390, in fit
    Xt = self._fit(X, y, **fit_params_steps)
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\sklearn\pipeline.py", line 355, in _fit
    **fit_params_steps[name],
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\joblib\memory.py", line 349, in __call__
    return self.func(*args, **kwargs)
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\sklearn\pipeline.py", line 893, in _fit_transform_one
    res = transformer.fit_transform(X, y, **fit_params)
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\sklearn\compose\_column_transformer.py", line 671, in fit_transform
    self._validate_transformers()
  File "E:\Anaconda\install\envs\pytorch\lib\site-packages\sklearn\compose\_column_transformer.py", line 339, in _validate_transformers
    "specifiers. '%s' (type %s) doesn't." % (t, type(t))
TypeError: All estimators should implement fit and transform, or can be 'drop' or 'passthrough' specifiers. '['Sex', 'Embarked']' (type <class 'list'>) doesn't.
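The error has nothing to do with the Titanic CSV itself. Here is a stripped-down reproduction, a sketch on a made-up three-row DataFrame (the data and the names `df`, `bad` and `raised` are illustrative, not from the code above) that keeps the same (columns, pipeline) order. Fitting it should raise the same kind of TypeError, though the exact message wording may differ between scikit-learn versions.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Made-up stand-in for two Titanic columns (not the real data)
df = pd.DataFrame({
    "Sex": ["male", "female", "female"],
    "Age": [22.0, np.nan, 35.0],
})

# Same order as the original code: column list first, pipeline second
bad = make_column_transformer(
    (["Sex"], make_pipeline(SimpleImputer(strategy="constant"),
                            OneHotEncoder(handle_unknown="ignore"))),
    (["Age"], make_pipeline(SimpleImputer(), StandardScaler())),
)

try:
    bad.fit_transform(df)  # validation of the transformers happens at fit time
    raised = False
except TypeError as exc:
    raised = True
    print(exc)  # complains that the list has no fit/transform
```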
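In the end, feeling like an idiot, I checked the official API docs and found out that in each tuple passed to make_column_transformer, the transformer (here, the pipeline) has to come first and the column names second. Because I had them the other way round, scikit-learn treated the column list ['Sex', 'Embarked'] as the transformer, and a plain list has no fit or transform method, hence the TypeError. So all you need to do is swap the two: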
preprocessor = make_column_transformer((pipe_cat, col_cat), (pipe_num, col_num))
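If you want to check the fix without the CSV file, here is a minimal sketch of the corrected pipeline on a hypothetical eight-row DataFrame with the same feature columns (the values are made up, not taken from train.csv, and names like `df`, `fitted_pre` and `names` are only for illustration). It also prints the step names that make_column_transformer generates on its own, since this helper doesn't take explicit names.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy rows shaped like the Titanic columns used above (not real data)
df = pd.DataFrame({
    "Sex": ["male", "female", "female", "male", "female", "male", "male", "female"],
    "Embarked": ["S", "C", np.nan, "S", "S", "C", "S", "C"],
    "Age": [22.0, 38.0, 26.0, np.nan, 35.0, 54.0, 2.0, np.nan],
    "SibSp": [1, 1, 0, 0, 1, 0, 3, 1],
    "Parch": [0, 0, 0, 0, 0, 0, 1, 0],
    "Fare": [7.25, 71.28, 7.93, 8.05, 53.1, 51.86, 21.08, 30.07],
    "Survived": [0, 1, 1, 0, 1, 0, 0, 1],
})
y = df["Survived"]
x = df.drop(columns=["Survived"])

col_cat = ["Sex", "Embarked"]
col_num = ["Age", "SibSp", "Parch", "Fare"]
pipe_cat = make_pipeline(SimpleImputer(strategy="constant"), OneHotEncoder(handle_unknown="ignore"))
pipe_num = make_pipeline(SimpleImputer(), StandardScaler())

# Correct order: (transformer, columns)
preprocessor = make_column_transformer((pipe_cat, col_cat), (pipe_num, col_num))
pipe = make_pipeline(preprocessor, LogisticRegression(solver="lbfgs"))
pipe.fit(x, y)

# Names are derived from the transformer class, numbered when they repeat
names = [name for name, _, _ in preprocessor.transformers]
print(names)  # ['pipeline-1', 'pipeline-2']

# 2 Sex categories + 3 Embarked categories (C, S, missing_value) + 4 numeric columns
fitted_pre = pipe.named_steps["columntransformer"]
n_rows, n_cols = fitted_pre.transform(x).shape
print(n_rows, n_cols)  # 8 9

accuracy = pipe.score(x, y)  # training accuracy on the toy rows, just a smoke test
print(accuracy)
```

If you'd rather choose the step names yourself, the ColumnTransformer class takes (name, transformer, columns) triples, and there too the transformer comes before the columns.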
This drove me up the wall!