What does the stratify parameter of cross_validation.train_test_split mean?

This post explains how, when splitting a dataset with sklearn, the stratify parameter keeps the class proportions of the training and test sets consistent, which is especially important for class-imbalanced data.

This is more rigorous than calling train_test_split without it.

stratify preserves the class distribution that existed before the split. Suppose there are 100 samples, 80 of class A and 20 of class B. With train_test_split(..., test_size=0.25, stratify=y_all), the split comes out as:
training: 75 samples, of which 60 are class A and 15 are class B.
testing: 25 samples, of which 20 are class A and 5 are class B.
With stratify, both the training set and the test set keep the class ratio A:B = 4:1, the same as before the split (80:20). stratify is typically used when the class distribution is imbalanced like this.
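The 80/20 example above can be checked directly in a few lines. The data here is synthetic (classes encoded as 0 for A and 1 for B), just to make the preserved ratio visible:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 100 samples: 80 of class A (label 0) and 20 of class B (label 1).
y_all = np.array([0] * 80 + [1] * 20)
X_all = np.arange(100).reshape(-1, 1)  # dummy features

X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.25, stratify=y_all, random_state=0
)

# Both splits keep the original 4:1 class ratio exactly.
print(np.bincount(y_train))  # [60 15]
print(np.bincount(y_test))   # [20  5]
```

Without stratify=y_all, the per-split class counts would fluctuate with random_state; with it, they are fixed by the original proportions.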

The sklearn documentation does not explain this parameter very clearly. (Note that the sklearn.cross_validation module has since been removed; train_test_split now lives in sklearn.model_selection.)

Documentation:

http://scikit-learn.org/stable/modules/generated/sklearn.cross_validation.train_test_split.html

```
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Loading function unchanged
def processTarget():
    main_folder = 'C:/Users/Lenovo/Desktop/crcw不同端12k在0负载下/风扇端'
    data_list = []
    label_list = []
    for folder_name in sorted(os.listdir(main_folder)):
        folder_path = os.path.join(main_folder, folder_name)
        if os.path.isdir(folder_path):
            csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]
            print(f"Processing folder: {folder_name}, found {len(csv_files)} CSV files.")
            for filename in sorted(csv_files):
                file_path = os.path.join(folder_path, filename)
                csv_data = pd.read_csv(file_path, header=None)
                if csv_data.shape[1] >= 4:
                    csv_data = csv_data.iloc[:, [0, 1, 2]].values
                else:
                    print(f"Skipping file (unknown), unexpected shape: {csv_data.shape}")
                    continue
                # Determine the label BEFORE appending the data; otherwise a folder
                # matching neither fault type would leave data_list and label_list
                # with different lengths.
                if '内圈故障' in folder_name:
                    class_label = 0
                elif '球故障' in folder_name:
                    class_label = 1
                else:
                    continue
                data_list.append(csv_data)
                label_list.append(class_label)
    if data_list and label_list:
        data = np.array(data_list)     # Shape: (num_samples, seq_length, num_features)
        labels = np.array(label_list)  # Shape: (num_samples,)
        return data, labels
    else:
        raise ValueError("No valid data available to process.")

# Split the dataset
def split_datasets(X, y, test_size=0.2, val_size=0.25):
    """
    :param X: feature array
    :param y: label array
    :param test_size: test-set fraction, default 0.2 (leaving 80% for train + validation)
    :param val_size: validation fraction of the remaining training data, default 0.25
    """
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=42
    )
    # Carve the validation set out of the remaining data
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=val_size, stratify=y_train_val, random_state=42
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

if __name__ == "__main__":
    try:
        data0, label0 = processTarget()
        # Split into training, validation, and test sets
        X_train, X_val, X_test, y_train, y_val, y_test = split_datasets(data0, label0)
        print("Training Set:", X_train.shape, y_train.shape)
        print("Validation Set:", X_val.shape, y_val.shape)
        print("Testing Set:", X_test.shape, y_test.shape)
        # Save the splits for later steps
        np.savez('datasets.npz', X_train=X_train, y_train=y_train,
                 X_val=X_val, y_val=y_val, X_test=X_test, y_test=y_test)
    except ValueError as e:
        print(e)
```
This is my code for splitting the dataset into training, test, and validation sets. Now, building on this code, I want to apply Deep Domain Confusion (DDC) to this dataset. Could you give complete code? Requirement: the dataset splitting and the Deep Domain Confusion part should be in two separate files.
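A full DDC pipeline is beyond a comment, but its core idea is small: add a Maximum Mean Discrepancy (MMD) penalty between source-domain and target-domain features to the classification loss, so the feature extractor learns domain-invariant representations. Below is a minimal numpy sketch of just the linear-kernel MMD term (the function name mmd_linear and all data here are illustrative, not part of any library); in a real DDC model this would be computed on an adaptation layer's activations and weighted into the training loss:

```python
import numpy as np

def mmd_linear(source_feats, target_feats):
    """Squared Maximum Mean Discrepancy with a linear kernel.

    DDC-style training minimizes something like
        loss = classification_loss + lambda * mmd_linear(src, tgt)
    on the activations of a bottleneck ("adaptation") layer.
    """
    delta = source_feats.mean(axis=0) - target_feats.mean(axis=0)
    return float(delta @ delta)  # squared norm of the mean difference, always >= 0

# Matched feature distributions give a small discrepancy; shifted ones a large one.
rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(256, 16))
tgt_same = rng.normal(0.0, 1.0, size=(256, 16))
tgt_shift = rng.normal(2.0, 1.0, size=(256, 16))
print(mmd_linear(src, tgt_same) < mmd_linear(src, tgt_shift))  # True
```

In a deep-learning framework the same computation would run on minibatch feature tensors inside the training loop, with gradients flowing back through both domains' batches.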
```
ValueError                                Traceback (most recent call last)
Cell In[22], line 1
----> 1 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
      2 dtrain = xgb.DMatrix(X_train, label=y_train)
      3 dtest = xgb.DMatrix(X_test, label=y_test)
...
File ~\anaconda3\Lib\site-packages\sklearn\utils\validation.py:475, in check_consistent_length(*arrays)
--> 475     raise ValueError(
    476         "Found input variables with inconsistent numbers of samples: %r"
    477         % [int(l) for l in lengths]
    478     )
ValueError: Found input variables with inconsistent numbers of samples: [4, 2]
```
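This error means the arrays passed to train_test_split do not all have the same number of rows; here X has 4 samples but y has only 2 labels. A minimal reproduction and fix (the arrays below are placeholders, not the original data):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.zeros((4, 3))  # 4 samples, 3 features
y = np.zeros(2)       # only 2 labels -> length mismatch

try:
    train_test_split(X, y, test_size=0.5)
except ValueError as e:
    print(e)  # Found input variables with inconsistent numbers of samples: [4, 2]

# Fix: every array passed must share the same first dimension.
y_ok = np.zeros(4)
X_train, X_test, y_train, y_test = train_test_split(X, y_ok, test_size=0.5, random_state=0)
print(X_train.shape, y_train.shape)  # (2, 3) (2,)
```

The second traceback below ([4, 490129]) is the same problem: check features.shape and labels.shape before the call; a transposed feature matrix is a common cause of such a mismatch.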
```
ValueError                                Traceback (most recent call last)
Cell In[76], line 2
      1 from sklearn.model_selection import train_test_split
----> 2 features_train, features_test, labels_train, labels_test = train_test_split(features, labels, test_size=0.3, random_state=0)
      3 print("训练集Feature的矩阵规模:", features_train.shape)
      4 print("训练集Tag的矩阵规模:", labels_train.shape)
...
File D:\Anaconda\Lib\site-packages\sklearn\utils\validation.py:475, in check_consistent_length(*arrays)
--> 475     raise ValueError(
    476         "Found input variables with inconsistent numbers of samples: %r"
    477         % [int(l) for l in lengths]
    478     )
ValueError: Found input variables with inconsistent numbers of samples: [4, 490129]
```