Sklearn : train_test_split()函数的用法

本文介绍了sklearn库中的train_test_split函数,详细解释了其参数作用,包括X和y代表的数据集和标签,train_size和test_size用于指定划分比例,random_state控制随机性,shuffle决定是否打乱数据,以及stratify保持标签分布均衡。通过实例展示了不同参数设置下的数据划分效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、train_test_split官方文档链接

二、参数介绍
① X :(必需) 待划分的样本集
② y :(非必需) 样本标签target(如果你只是想把数据简单的分为两部分,不涉及分类算法等需要标注数据标签的情况就无须设置)
③ train_size : (非必需) int型或float型,整型表示划分后的数据个数;浮点型表示划分数据的比例。
④ test_size :(非必需) 同上
⑤ random_state :(非必需) int 类型,默认值为None。先笼统的认为是一个控制分裂过程随机性的一个参数。不用管内部实现过程。
⑥ shuffle :(非必需) 默认为True。控制拆分数据前,原始数据集是否需要打乱再拆分。
⑦ stratify :(非必需)

三、自己动手看一下效果
1. 自定义一个数据集:

import numpy as np

x = np.arange(1, 25).reshape(12, 2)
y = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])

print("原始数据集:\n",x)
print("数据标签:\n",y)

输出如下:

原始数据集:
 [[ 1  2]
 [ 3  4]
 [ 5  6]
 [ 7  8]
 [ 9 10]
 [11 12]
 [13 14]
 [15 16]
 [17 18]
 [19 20]
 [21 22]
 [23 24]]
数据标签:
 [0 1 1 0 1 0 0 1 1 0 1 0]

2.先用最简单的执行方式,了解一下各个参数的作用
首先默认一切非必须参数为默认值:

from  sklearn.model_selection import train_test_split
x_train,x_test = train_test_split(x)
print("训练数据为:\n",x_train)
print("测试数据为:\n",x_test)

输出如下:

训练数据为:
 [[ 9 10]
 [13 14]
 [ 5  6]
 [11 12]
 [15 16]
 [21 22]
 [23 24]
 [ 7  8]
 [ 1  2]]
测试数据为:
 [[17 18]
 [19 20]
 [ 3  4]]

由于shuffle默认值为True,即原始数据在划分前做了打乱处理,因此分裂后的数据是乱序的。若此时重新执行当前代码,输出结果如下:

训练数据为:
 [[ 7  8]
 [23 24]
 [ 1  2]
 [ 3  4]
 [19 20]
 [15 16]
 [21 22]
 [17 18]
 [13 14]]
测试数据为:
 [[11 12]
 [ 5  6]
 [ 9 10]]

可以发现,输出变了,(emmm,大部分实验场景下,这种每次都变的输入并不利于后续的研究分析,应该避免,但是也不排除有些情况下,就是希望数据每次都不一样,OK,你自己判断一下。我不管。。。)。若要避免这种情况有两种解决方式,分别是设置shuffle和random_state,详情见3、4。

  1. 设置shuffle
    设置shuffle为False,即在拆分数据前,不对原始数据集进行打乱,但是划分结果后的数据按原始数据的顺序排列。
x_train,x_test = train_test_split(x,shuffle=False)

输出结果如下:

训练数据为:
 [[ 1  2]
 [ 3  4]
 [ 5  6]
 [ 7  8]
 [ 9 10]
 [11 12]
 [13 14]
 [15 16]
 [17 18]]
测试数据为:
 [[19 20]
 [21 22]
 [23 24]]
  1. 设置random_state
    若你仍希望拆分前对原始数据集进行打乱处理,又希望每次重复执行你的代码时,划分后的数据能够相同,此时就可以设置rrandom_state,至于设置为多少我还没有搞懂,根据经验在1到10随便设置选一个整数就OK,结果没有影响。
x_train,x_test = train_test_split(x,random_state=1)

重复执行两次上述代码,划分后的数据集两次相同,输出结果如下:

训练数据为:
 [[21 22]
 [ 3  4]
 [13 14]
 [ 1  2]
 [15 16]
 [23 24]
 [19 20]
 [17 18]
 [11 12]]
测试数据为:
 [[ 5  6]
 [ 7  8]
 [ 9 10]]
  1. 设置stratify
    如果你的数据压根没有标签,就不用管这个参数;如果有标签,假设标签分别为0和1,表示正负样本,设置stratify就可以保证划分后的训练集和测试集中正负样本的比例近似相同。
x = np.arange(1, 25).reshape(12, 2)
y = np.array([0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])
x_train, x_test, y_train, y_test = train_test_split(x, y, stratify=y)

输出结果如下:

数据标签:
 [0 1 1 0 1 0 0 1 1 0 1 0]
训练数据为:
 [[21 22]
 [ 3  4]
 [ 7  8]
 [17 18]
 [15 16]
 [ 1  2]
 [ 9 10]
 [13 14]
 [11 12]] 
训练数据标签为:
 [1 1 0 1 1 0 1 0 0]  #(emmm这里样本总个数太少了,不明显,你可以自己试一下大样本的结果)正负样本比例4/5
测试数据为:
 [[23 24]
 [ 5  6]
 [19 20]] 
测试数据标签为:        #正负样本比例2/1
 [0 1 0]
### 正确导入 `train_test_split` 模块的方法 在 Python 中,如果尝试通过 `from sklearn.cross_validation import train_test_split` 导入模块时会引发错误,这是因为自 scikit-learn 版本 0.20 起,`cross_validation` 已被弃用并替换为 `model_selection` 模块[^1]。因此,应改为使用以下语句: ```python from sklearn.model_selection import train_test_split ``` 然而,在某些情况下即使执行上述代码仍可能触发 ImportError 错误。这通常是因为当前环境中未安装最新版本的 scikit-learn 或者存在多个不同版本冲突的情况[^4]。 #### 解决方案一:更新 scikit-learn 库至最新版 为了确保能够成功调用 `train_test_split` 函数,建议先升级本地环境中的 scikit-learn 到最新稳定版本。可以运行如下命令完成操作: ```bash pip install --upgrade scikit-learn ``` 或者如果你正在使用 conda,则可以通过下面这条指令来实现相同目的: ```bash conda update scikit-learn ``` #### 解决方案二:检查 IDE 配置与依赖管理工具设置 当开发人员报告即便已经正确安装了所需库仍然遭遇同样的问题时,可能是由于集成开发环境 (IDE) 如 PyCharm 的内部虚拟环境配置不当所致[^5]。此时需确认所使用的解释器确实包含了最新的 scikit-learn 包。具体做法是在 PyCharm 设置界面里找到对应项目的 Interpreter Options ,然后手动添加缺少的包或重新初始化整个工作区内的外部库列表。 综上所述,要彻底消除此类 ImportErrors 并顺利加载 `train_test_split` 方法,除了修正语法外还需要关注实际部署场景下的软件生态链状态以及合理调整相关参数选项。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值