[机器学习]三行代码快速划分交叉训练中训练集和验证集

本文介绍了一种使用numpy.random.choice()和set()快速划分训练集和验证集的方法,并展示了如何利用该方法进行批量训练。

使用numpy.random.choice()和set()快速划分交叉训练数据集

之前在划分训练集和验证集时,都是手工随机生成index,很笨。

学到的新方法如下:

import numpy as np
# 正态分布生成原始数据
x = np.random.random.normal(1,0.1,100)
# 按8:2分割数据
x_train_index = np.random.choice(len(x),round(len(x)*0.8),replace = False)
x_valid_index = np.array(list(set(range(len(x))) - set(x_train_index)))

x_train = x[x_train_index]
x_valid = x[x_valid_index]

总结1: np.random.choice()

Definition : choice(a, size=None, replace=True, p=None)

Type : Function of None module

Parameters
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a was np.arange(n)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. Default is None, in which case a single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement
是否包含重复元素
p : 1-D array-like, optional
The probabilities associated with each entry in a. If not given the sample assumes a uniform distribution over all entries in a.
按什么概率分布选取元素,默认是均匀分布

Returns
samples : 1-D ndarray, shape (size,)
The generated random samples

总结2: set()

Python的集合(set)和其他语言类似, 是一个无序不重复元素集, 基本功能包括关系测试和消除重复元素.

总结3: batch training

batch training 一样可以使用这种方法选取数据

batch_size = 25
for epoch in range(100):
    rand_index = np.random.choice(len(x_train), size = batch_size)
    rand_x = x_train[rand_index]
    rand_y = y_train[rand_index]
    ...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值