
☞ ░ 前往老猿Python博客 ░ https://blog.youkuaiyun.com/LaoYuanPython
一、概述
在机器学习中,数据预处理是非常重要的工作,常规执行顺序:数据分割 → 2. 数据清洗 → 3. 特征编码 → 4. 特征构造 → 5. 异常值处理 → 6. 特征缩放 → 7. 降维,可以看到数据分割是数据预处理的第一步。
任何基于统计的预处理(如均值填充、标准化)都应仅在训练集上计算参数,再应用到测试集,这样测试集才能起到评估训练后模型的效果。
Scikit-learn 提供了多种数据分割方法,适用于不同机器学习场景,包括简单随机分割train_test_split、标准K折交叉验证KFold、分层K折交叉验证StratifiedKFold、重复多次K折验证RepeatedKFold/RepeatedStratifiedKFold、时间序列分割TimeSeriesSplit、分组分割GroupKFold、LeaveOneGroupOut、留一法LeaveOneOut/LeavePOut。
这些方法一些使用上的建议:
- train_test_split:快速验证的简单随机数据分割
- KFold:适合回归任务,注意大数据集可减少n_splits
- StratifiedKFold:适合数据集不均衡分类 ,分类任务优先使用
- TimeSeriesSplit:适合时间序列
- GroupKFold:适合分组数据
- LeaveOneOut:适合小数据集
- RepeatedKFold:适合结果不稳定时使用
这些方法都可通过sklearn.model_selection导入,是模型评估的重要工具。根据数据特性和任务需求选择合适的分割方式,能显著提高模型评估的可靠性。
任何基于统计的预处理(如均值填充、标准化)都应仅在训练集上计算参数,再应用到测试集,这样测试集才能起到评估训练后模型的效果。
下面是老猿介绍sklearn中数据分割的相关方法的博文列表:
- 《Scikit-learn(sklearn)机器学习的数据分割方法–train_test_split和KFold https://blog.youkuaiyun.com/LaoYuanPython/article/details/149479366》介绍了简单数据分割train_test_split方法和K折交叉验证KFold类
- 《Scikit-learn(sklearn)机器学习的数据分割方法–分层K折交叉验证StratifiedKFold https://blog.youkuaiyun.com/LaoYuanPython/article/details/150060086》》介绍分层K折交叉验证StratifiedKFold的数据分割方法
- 《Scikit-learn(sklearn)机器学习的数据分割方法–重复 K 折RepeatedKFold https://blog.youkuaiyun.com/LaoYuanPython/article/details/149324637》介绍重复 K 折RepeatedKFold 的数据分割方法。
- …
二、分层K折交叉验证StratifiedKFold介绍
2.1、方法原理
StratifiedKFold 是 scikit-learn 提供的一种交叉验证方法,在 K-Fold 的基础上进行了改进,将数据集分成 k 个折叠(folds),确保每个折叠中各类别样本的比例与原始数据集中的比例相同,保持了每个折叠中类别的分布比例与原始数据集一致,特别适用于分类问题,尤其是类别分布不平衡的数据集,避免因随机分割导致某些折中缺少少数类样本。在类别不平衡情况下,比kfold能提供更可靠的模型评估。
因回归目标为连续值无法分层,StratifiedKFold仅用于分类任务,不适用于回归任务。
StratifiedKFold在分类任务中,保持每折的类别比例与原始数据集一致,特别适用于类别不均衡数据(如正负样本比例 9:1)和类别分布不均衡的分类任务(如医疗诊断、欺诈检测)。
2.2、StratifiedKFold初始化构造方法
分层K折交叉验证在sklearn中的实现使用StratifiedKFold类实现,其初始化参数如下:
- n_splits:分割的折数 K(必须 ≥ 2),整型,默认值5,推荐值5 或 10(平衡计算成本和评估可靠性)。若设为 n_splits=5,则生成 5 折,训练集占 80%,验证集占 20%。对小数据集可增加折数(如 n_splits=10)以提高评估稳定性,对大数据集减少折数(如 n_splits=3)以降低计算成本;
- shuffle: 是否在分割前打乱数据顺序,布尔型,默认值False,表示按数据原始顺序分割(适用于时间序列以外的场景),为True则先打乱数据再分割(避免因原始顺序引入偏差);
- random_state:随机种子(仅当 shuffle=True 时生效),整型,默认值为None,固定随机种子为某个数值可确保每次运行结果一致
这三个参数在对象初始化后,可以通过n_splits、shuffle和random_state属性去访问。
2.3、StratifiedKFold的方法
StratifiedKFold提供了两个方法,一个是数据集划分的方法split,一个是获取交叉验证分割的数量get_n_splits。
2.3.1、split方法
split方法用于生成训练集和测试集的索引。
-
语法:split(X, Y, groups=None),其中:
1). X为特征数据集,Y为目标标签数据集
2). groups可选,一维数组,用于分组的数据,即每个样本对应有一个分组ID(例如同一个患者的多次测量、同一家庭的多个成员,样本数可能是多个,但它们的分组是一个),为 X 中每个样本分配一个组标识符,标识样本所属的逻辑分组。通常不使用,主要用于确保同一组的数据在交叉验证时不被拆分到不同的折叠中。 -
返回结果:返回一个生成器,这个生成器用于生成训练集和测试集数据的索引元组
2.3.2、get_n_splits方法
语法:get_n_splits(X=None, Y=None, groups=None)
返回:整数,表示分割数(n_splits)
在 StratifiedKFold 中,n_splits 属性和 get_n_splits() 方法虽然都能获取交叉验证的折数,既然有了属性值直接访问为什么还要提供方法呢?这是因为这个方法是 scikit-learn 交叉验证器的统一接口(所有交叉验证类都必须实现,Python的鸭子类型特征),目的是为了与其他组件(如 GridSearchCV、cross_val_score)兼容,这些组件通过调用方法而非直接访问属性来获取折数。
get_n_splits() 方法的参数(如 X、Y、groups)为未来扩展预留了空间。虽然 StratifiedKFold 中这些参数通常无用,但在其他验证器中可能影响实际折数,所以为了实现通用的接口,定义了这三个参数。
三、案例
下面实现了加载鸢尾花数据集并使用StratifiedKFold进行数据分割得到案例:
from sklearn.model_selection import StratifiedKFold
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy as np
# 加载数据
X, Y = load_iris(return_X_y=True) #当 return_X_y=True 时,函数会直接返回一个元组 (X, y)
# 初始化 StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=20209)
print(skf.n_splits,skf.shuffle,skf.random_state)
# 初始化模型
model = LogisticRegression()
for fold, (train_idx, test_idx) in enumerate(skf.split(X,Y), 1):
print(f"\n第 {fold}折:")
print(f" 训练集中的数据索引情况: {train_idx[:5]}... (总数: {len(train_idx)})")
print(f" 验证集中的数据索引情况: {test_idx[:5]}... (总数: {len(test_idx)})")
#print(f" 训练集中的数据: {Y[train_idx]}")
print(f" 验证集中的数据: {Y[test_idx]}")
# 检查分层情况
counts = np.bincount(Y[train_idx])
print(" 训练集中的类别和数据分布数:",np.arange(len(counts)), counts)
counts = np.bincount(Y[test_idx])
print(" 验证集中的类别和数据分布数:", np.arange(len(counts)), counts)
# 进行交叉验证
X_train, X_test = X[train_idx], X[test_idx]
Y_train, Y_test = Y[train_idx], Y[test_idx]
model.fit(X_train, Y_train)
score = model.score(X_test, Y_test)
print(f"本折数据验证结果精确度: {score*100:.2f}%")
运行输出:
5 True 20209
第 1折:
训练集中的数据索引情况: [0 1 2 3 4]... (总数: 120)
验证集中的数据索引情况: [21 26 29 32 36]... (总数: 30)
验证集中的数据: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2]
训练集中的类别和数据分布数: [0 1 2] [40 40 40]
验证集中的类别和数据分布数: [0 1 2] [10 10 10]
本折数据验证结果精确度: 93.33%
第 2折:
训练集中的数据索引情况: [0 1 3 4 6]... (总数: 120)
验证集中的数据索引情况: [ 2 5 9 13 14]... (总数: 30)
验证集中的数据: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2]
训练集中的类别和数据分布数: [0 1 2] [40 40 40]
验证集中的类别和数据分布数: [0 1 2] [10 10 10]
本折数据验证结果精确度: 96.67%
第 3折:
训练集中的数据索引情况: [0 2 3 4 5]... (总数: 120)
验证集中的数据索引情况: [ 1 10 12 15 19]... (总数: 30)
验证集中的数据: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2]
训练集中的类别和数据分布数: [0 1 2] [40 40 40]
验证集中的类别和数据分布数: [0 1 2] [10 10 10]
本折数据验证结果精确度: 90.00%
第 4折:
训练集中的数据索引情况: [1 2 3 5 6]... (总数: 120)
验证集中的数据索引情况: [ 0 4 11 16 23]... (总数: 30)
验证集中的数据: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2]
训练集中的类别和数据分布数: [0 1 2] [40 40 40]
验证集中的类别和数据分布数: [0 1 2] [10 10 10]
本折数据验证结果精确度: 100.00%
第 5折:
训练集中的数据索引情况: [0 1 2 4 5]... (总数: 120)
验证集中的数据索引情况: [ 3 6 7 8 22]... (总数: 30)
验证集中的数据: [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2]
训练集中的类别和数据分布数: [0 1 2] [40 40 40]
验证集中的类别和数据分布数: [0 1 2] [10 10 10]
本折数据验证结果精确度: 100.00%
请大家结合前面的介绍去理解这个案例,如果存在问题请评论去留言。
四、小结
本文介绍了机器学习的数据分割方法–分层K折交叉验证StratifiedKFold的原理以及在sklearn中的实现,并结合案例介绍了StratifiedKFold的使用。老猿认为StratifiedKFold这个方法在kfold的思想基础上,借鉴了train_test_split使用分类标签stratify的场景。
更多人工智能知识学习过程中可能遇到的疑难问题及解决办法请关注专栏《零基础机器学习入门》及付费专栏《机器学习疑难问题集》后续的文章。
写博不易,敬请支持:
如果阅读本文于您有所获,敬请点赞、评论、收藏,谢谢大家的支持!
关于老猿的付费专栏
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_9607725.html 使用PyQt开发图形界面Python应用》专门介绍基于Python的PyQt图形界面开发基础教程,对应文章目录为《 https://blog.youkuaiyun.com/LaoYuanPython/article/details/107580932 使用PyQt开发图形界面Python应用专栏目录》;
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10232926.html moviepy音视频开发专栏 )详细介绍moviepy音视频剪辑合成处理的类相关方法及使用相关方法进行相关剪辑合成场景的处理,对应文章目录为《https://blog.youkuaiyun.com/LaoYuanPython/article/details/107574583 moviepy音视频开发专栏文章目录》;
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10581071.html OpenCV-Python初学者疑难问题集》为《https://blog.youkuaiyun.com/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的伴生专栏,是笔者对OpenCV-Python图形图像处理学习中遇到的一些问题个人感悟的整合,相关资料基本上都是老猿反复研究的成果,有助于OpenCV-Python初学者比较深入地理解OpenCV,对应文章目录为《https://blog.youkuaiyun.com/LaoYuanPython/article/details/109713407 OpenCV-Python初学者疑难问题集专栏目录 》
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10762553.html Python爬虫入门 》站在一个互联网前端开发小白的角度介绍爬虫开发应知应会内容,包括爬虫入门的基础知识,以及爬取优快云文章信息、博主信息、给文章点赞、评论等实战内容。
前两个专栏都适合有一定Python基础但无相关知识的小白读者学习,第三个专栏请大家结合《https://blog.youkuaiyun.com/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的学习使用。
对于缺乏Python基础的同仁,可以通过老猿的免费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_9831699.html 专栏:Python基础教程目录)从零开始学习Python。
如果有兴趣也愿意支持老猿的读者,欢迎购买付费专栏。

638

被折叠的 条评论
为什么被折叠?



