scikit_learn中fit()/transform()/fit_transform()区别和联系

该博客介绍了如何使用scikit-learn库中的SimpleImputer类来处理数据集中的缺失值。通过fit()方法学习数据集特征,如中位数,然后使用transform()或fit_transform()将学到的特征应用到数据中,填充缺失值。fit_transform()适用于训练集,而transform()则用于测试集以保持标准化的一致性。强调了在预处理阶段,对训练集和测试集的处理必须保持一致。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

函数功能解释

fit() 

根据训练集数据学习得到数据集的特征,比如均值、中位数、标准差等等

transform()

将fit()学到数据集特征,应用到数据集,比如学习到数据集平均数为6,应用到填充数据中的缺失值

fit_transform() = fit()+transform()

即将从数据集中学到的特征(均值、中位数、标准差)应用到数据集中

举例

此处以使用均值填充缺失值举例

>>> import numpy as np
>>> from sklearn.impute import SimpleImputer

# 学习方法(策略)
>>> imp_mean = SimpleImputer(missing_values=np.nan, strategy='median') 

# 学习方法(策略)从下面数据集中进行学习
[[ 7.  2.  3.]
 [ 4. nan  6.]
 [10.  5.  9.]]

>>> imp_mean_fit = imp_mean.fit([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]]) 

# 学习方法(策略)从实际数据中进行学习,学习到的第1/2/3列中位数 分别是7.0,3.5,6.0
>>> imp_mean_fit.statistics_ 

array([7. , 3.5, 6. ])

# 将学习到的结果(第1/2/3列中位数 分别是7.0,3.5,6.0)应用到数据集X,原来的np.nan分别被中位数替换
>>> X = [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]]
>>> print(imp_mean_fit.transform(X))

[[ 7.   2.   3. ]
 [ 4.   3.5  6. ]
 [10.   3.5  9. ]]

应用

fit()+transform() 以及 fit_transform()只能应用在训练集,一般不能对测试集进行使用,测试集合一般用transfrom(),即只能将训练集提取到特征应用到训练集及测试集。原因是如果fit_transfrom(trainData)后,使用fit_transform(testData)而不transform(testData),虽然也能归一化,但是两个结果不是在同一个“标准”下的,具有明显差异。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值