tensorflow使用tf.data.Dataset 处理大型数据集

本文介绍了如何利用tensorflow的tf.data.Dataset模块处理大型数据集,避免一次性加载内存。主要内容包括:构建生成器函数,使用from_generator创建Dataset,指定batch_size,以及通过make_one_shot_iterator获取数据。在生成器中,利用yield生成每批次数据,确保yield返回的shape与Dataset定义一致。最后强调了get_next()在迭代器工作中的重要性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近深度学习用到的数据集比较大,如果一次性将数据集读入内存,那服务器是顶不住的,所以需要分批进行读取,这里就用到了tf.data.Dataset构建数据集,先看一个博文,入入门:

https://www.jianshu.com/p/f580f4fc2ba0

概括一下,tf.data.Dataset主要有几个部分最重要:

  1. 构建生成器函数
  2. 使用tf.data.Dataset的from_generator函数,通过指定数据类型,数据的shape等参数,构建一个Dataset
  3. 指定batch_size
  4. 使用make_one_shot_iterator()函数,构建一个iterator
  5. 使用上面构建的迭代器开始get_next() 。(必须要有这个get_next(),迭代器才会工作)

一.构建生成器

生成器的要点是要在while True中加入yield,yield的功能有点类似return,有yield才能起到迭代的作用。
我的数据是一个[6047, 6000, 1]的文本数据,我每次迭代返回的shape为[1,6000,1],要注意的是返回的shape要和构建Dataset时的shape一致,下面会说到。代码如下:

def gen():				
		train=pd.read_csv('/home/chenqiren/PycharmProjects/code/test/formal/small_sample/train2.csv', header=None)
        train.fillna(0, inplace = True)
        label_encoder = LabelEncoder().fit(train[6000])
        label = label_encoder.transform(train[6000])  
        train = train.drop([6000], axis=1) 
        scaler = StandardScaler().fit(train.values)   #train.values中的值是csv文件中的那些值,     这步标准化可以保留
        scaled_train = scaler.transform(train.values)
        #print(scaled_train)
        #拆分训练集和测试集--------------
        sss=StratifiedShuffleSplit(test_size=0.1, random_state=23)
        for train_index, valid_index in sss.split(scaled_train, label):   #需要的是数组,train.values得到的是数组
            X_train, X_valid=scaled_train[train_index], scaled_train[valid_index]  #https://www.cnblogs.com/Allen-rg/p/9453949.html
            y_train, y_valid=label[train_index], label[valid_index]
        X_train_r=np.zeros((len(X_train), 6000, 1))   #先构建一个框架出来,下面再赋值
        X_train_r[:,: ,0]=X_train[:,0:6000]     
    
        X_valid_r=np.zeros((len(X_valid), 6000, 1))
        X_valid_r[:,: ,0]=X_valid[:,0:6000]
    
        y_train=np_utils.to_categorical(y_train, 3)
        y_valid=np_utils.to_categorical(y_valid, 3)
        
        leng=len(X_train_r)
        index=0
        while True:
            x_train_batch=X_train_r[index, :, 0:1]
            y_train_batch=y_train[index, :]
            yield (x_train_batch, y_train_batch)
            index=index+1
            if index>leng:
                break

代码中while True上面的部分是标准化数据的代码,可以不用看,只需要看 while True中的代码即可。x_train_batch, y_train_batch都只是一行的数据,这里是一行一行数据迭代。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值