sklearn.utils.shuffle解析

使用sklearn.shuffle打乱样本
本文介绍如何使用Python的sklearn.utils.shuffle函数打乱样本数据。此函数提供了一种简便的方法来对数组或稀疏矩阵进行一致的随机排列。文章详细解析了函数的工作原理及其参数设置。
部署运行你感兴趣的模型镜像

在进行机器学习时,经常需要打乱样本,这种时候Python中叒有第三方库提供了这个功能——sklearn.utils.shuffle。

Shuffle arrays or sparse matrices in a consistent way. This is a convenience alias to resample(*arrays, replace=False) to do random permutations of the collections.

函数参数

Parameters

参数介绍
*array带索引的序列,可以是arrays, lists, dataframes或scipy sparse matrices
random_stateint,随机量,就是一个random seed。如果是int,该参数作为random seed的值;如果是None,随机生成器就是一个np.random实例
n_sampleint,默认为None,输出的样本数目。如果是空,则样本数目会设置为array的第一维元素数

Returns

参数介绍
shuffled_arrays带索引的序列,是一个view(也就是说不会改变输入array)

Examples

这里写图片描述

解释:例程中建立了3个带索引的序列:array, array和sparse matrix。然后将它们作为一个元组进行shuffle,其中random_state=0表示它们的打乱方式是方式0。这个打乱方式不理解的可以看一下np.random.seed的介绍或者是看我接下来对源码的解析。

源码

shuffle

def shuffle(*arrays, **options):
    options['replace'] = False
    return resample(*arrays, **options)

Are you kidding? 这是个“空壳函数”。唯一的作用就是将一个参数replace置为了False,好让shuffle过程中不影响输入array(不过要记住这个replace,这是sklearn.utils.shufflesklearn.utils.resample唯一的区别)。

那么下面来看resample函数。

resample

def resample(*arrays, **options):
    '''先是类型检测部分,可以跳过'''
    random_state = check_random_state(options.pop('random_state', None))  # 此处注意:返回类型变了,变成:np.random.mtrand._rand或np.random.RandomState(seed)或seed
    replace = option.pop('replace', True)  # 如果没有‘replace’则返回True
    max_n_samples = options.pop('n_samples', None)
    if options:
        raise ValueError("Unexpected kw arguments: %r" % options.keys())

    if len(arrays) == 0:
    return None

    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, 'shape') else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        raise ValueError("Cannot sample %d out of arrays with dim %d when replace is False" % (max_n_samples, n_samples))
    check_consistent_length(*array)
    '''开始正文'''
    '''重排索引'''
    if replace:
        indices = random_state.randint(0, n_samples, size=(max_n_samples,))  # 创建新的随机序列索引indices
    else:
        indices = np.arange(n_samples)
        random_state.shuffle(indices)
        indices = indices[:max_n_samples]

    # convert sparse matrices to CSR for row-based indexing
    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
    '''根据indices对arrays进行采样'''
    resampled_arrays = [safe_indexing(a, indices) for a in arrays]
    '''分两种情况,一种是输入的*arrays参数只有一个序列,另一种是输入的*arrays参数是一元组的序列'''
    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays

解释:代码分为三部分:

  1. 类型检测
  2. 构建重排后的索引
  3. 根据索引输出序列

random_state在函数中用于产生随机索引

您可能感兴趣的与本文相关的镜像

Stable-Diffusion-3.5

Stable-Diffusion-3.5

图片生成
Stable-Diffusion

Stable Diffusion 3.5 (SD 3.5) 是由 Stability AI 推出的新一代文本到图像生成模型,相比 3.0 版本,它提升了图像质量、运行速度和硬件效率

utils.py:import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from torch.optim.lr_scheduler import StepLR from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import Sampler import numpy as np # import RandomErasing import os import math import argparse import scipy as sp import scipy.stats import pickle import random import scipy.io as sio from sklearn.decomposition import PCA from sklearn import metrics import matplotlib.pyplot as plt from scipy.io import loadmat from sklearn import preprocessing from sklearn.neighbors import KNeighborsClassifier from matplotlib import pyplot import torchnet as tnt from torch.utils.data.dataloader import default_collate # 自定义测试数据集 import torch.utils.data as data import torchvision.transforms as transforms def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25): # 0。9 1。1 alpha = np.random.uniform(*alpha_range) noise = np.random.normal(loc=0., scale=1.0, size=data.shape) return alpha * data + beta * noise def flip_augmentation(data): # arrays tuple 0:(7, 7, 103) 1=(7, 7) horizontal = np.random.random() > 0.5 # True vertical = np.random.random() > 0.5 # False if horizontal: data = np.fliplr(data) if vertical: data = np.flipud(data) return data class hsidataset_target(data.Dataset): def __init__(self, data,label): self.data=data self.label=label # self.trans=RandomErasing.RandomErasing() def __getitem__(self, index): img=self.data[index] img1=img return img,img1 def __len__(self): return len(self.data) def twist_loss(p1,p2,alpha=1,beta=1): eps=1e-7 #ensure calculate #eps=0 kl_div=((p2*p2.log()).sum(dim=1)-(p2*p1.log()).sum(dim=1)).mean() mean_entropy=-(p1*(p1.log()+eps)).sum(dim=1).mean() mean_prob=p1.mean(dim=0) entropy_mean=-(mean_prob*(mean_prob.log()+eps)).sum() return kl_div+alpha*mean_entropy-beta*entropy_mean
06-25
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值