使用tf.data.Dataset.from_tensor_slices五步加载数据集

博主记录了使用TF2加载MNIST数据集的过程。思路包括准备numpy数据、加载、打乱、预处理、设置值及循环迭代等步骤。代码方面提到one - hot编码及shuffle函数数值等问题,总结强调五个步骤重要,还有其他加载方法后续再谈,建议读API。
部署运行你感兴趣的模型镜像

前言:

最近在学习tf2
数据加载感觉蛮方便的
这里记录下使用 tf.data.Dataset.from_tensor_slices 进行加载数据集.
使用tf2做mnist(kaggle)的代码

思路

Step0: 准备要加载的numpy数据
Step1: 使用 tf.data.Dataset.from_tensor_slices() 函数进行加载
Step2: 使用 shuffle() 打乱数据
Step3: 使用 map() 函数进行预处理
Step4: 使用 batch() 函数设置 batch size
Step5: 根据需要 使用 repeat() 设置是否循环迭代数据集

代码

import tensorflow as tf
from tensorflow import keras

def load_dataset():
	# Step0 准备数据集, 可以是自己动手丰衣足食, 也可以从 tf.keras.datasets 加载需要的数据集(获取到的是numpy数据) 
	# 这里以 mnist 为例
	(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
	
	# Step1 使用 tf.data.Dataset.from_tensor_slices 进行加载
	db_train = tf.data.Dataset.from_tensor_slices((x, y))
	db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
	
	# Step2 打乱数据
	db_train.shuffle(1000)
	db_test.shuffle(1000)
	
	# Step3 预处理 (预处理函数在下面)
	db_train.map(preprocess)
	db_test.map(preprocess)

	# Step4 设置 batch size 一次喂入64个数据
	db_train.batch(64)
	db_test.batch(64)

	# Step5 设置迭代次数(迭代2次) test数据集不需要emmm
	db_train.repeat(2)

	return db_train, db_test

def preprocess(labels, images):
	'''
	最简单的预处理函数:
		转numpy为Tensor、分类问题需要处理label为one_hot编码、处理训练数据
	'''
	# 把numpy数据转为Tensor
	labels = tf.cast(labels, dtype=tf.int32)
	# labels 转为one_hot编码
	labels = tf.one_hot(labels, depth=10)
	# 顺手归一化
	images = tf.cast(images, dtype=tf.float32) / 255
	return labels, images
  1. one_hot 编码: 小姐姐给你解释去 (我在使用自带的fit函数进行训练的时候,发现报错维度不正确,原来是不需要one_hot编码)

  2. shuffle()函数的数值: 源码链接, 内容我贴图了
    函数定义源码
    我找到一个比较好的解释: 简书真是好东西

  3. 我发现 自己的数据使用tf.data.Dataset.from_tensor_slices(x, y)加载时, 一定要x在前y在后。。。没仔细看函数说明,否则会导致bug的emmm

  4. 使用了该函数之后, fit的时候是不支持 validation_split 这个参数提供的功能的~

总结

五个步骤很重要 比较简单的方式加载数据 当然还有其他方法加载 之后再说叭
此外, 建议读读api tf.data.Dataset 里好东西太多了~

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

<think>我们被要求介绍tf.data.Dataset.from_tensor_slices使用方法和示例,并且有多个引用可供参考。 根据引用[1]和[2],我们知道这个函数可以创建一个由输入tensor的切片构成的Dataset,它会沿着输入tensor的第一个维度进行切片,并保留输入tensors的结构。所有输入tensor的第一个维度必须相同。 引用[3]提供了一个简单的示例:将一个一维张量(通过[:, None]变成二维,但每个元素是一个长度为1的向量)传入,然后创建Dataset,然后迭代输出每个切片。 我们将按照以下结构回答: 1. 函数的作用和原理 2. 使用方法的详细说明(参数、返回值等) 3. 示例(多个示例,展示不同输入情况) 注意:由于引用[3]的示例是TensorFlow 2.1的,我们同样使用TensorFlow 2.x版本的语法。 另外,根据系统指令,数学公式使用$...$,独立公式使用$$...$$,但这里可能公式不多,注意行内公式格式。同时,最后生成相关问题。 我们开始: 函数签名:tf.data.Dataset.from_tensor_slices(tensors) 作用:将输入的张量(或元组、字典等)沿第一个维度切片,返回一个Dataset对象,其中每个元素是对应位置的切片。 注意: - 输入可以是单个张量、元组、字典或嵌套结构。 - 所有输入张量的第一个维度大小必须相同。 示例1:单个张量(一维) 输入一个一维张量,例如[1,2,3,4,5],则数据集包含5个元素,分别是1,2,3,4,5。 示例2:单个张量(二维) 输入一个二维张量,例如[[1,2],[3,4],[5,6]],则数据集包含3个元素,每个元素是一维张量:[1,2]、[3,4]、[5,6]。 示例3:多个张量(元组) 输入两个一维张量(第一个维度相同),例如([1,2,3], [4,5,6]),则数据集包含3个元素,每个元素是一个元组:(1,4), (2,5), (3,6) 示例4:字典 输入一个字典,例如{'a': [1,2], 'b': [3,4]},则数据集包含2个元素,每个元素是字典:{'a':1, 'b':3} 和 {'a':2, 'b':4} 我们将编写代码示例,并说明输出。 注意:由于在文本中,我们使用代码块来展示代码,并注释输出。 另外,我们还可以结合其他方法如batch, repeat等,但用户只问from_tensor_slices,所以我们主要展示这个函数,其他方法可以简单提及。 最后,生成相关问题。 开始写回答: </think> ### `tf.data.Dataset.from_tensor_slices` 的使用方法与示例 `tf.data.Dataset.from_tensor_slices` 是 TensorFlow 中用于构建数据集的核心 API,它通过**沿输入张量的第一个维度切片**生成数据集。以下详细说明其原理和使用方法: --- #### **功能原理** 1. **切片机制**: 将输入张量(或嵌套结构)沿第 0 维(最外层维度)切片,生成多个子元素。 例如: - 输入张量形状 `(N, ...)` → 输出 `N` 个形状为 `(...)` 的元素 - 输入多个张量时,**所有张量的第 0 维大小必须相同**,系统自动对齐切片[^2]。 2. **支持数据类型**: 支持单个张量、元组、字典或嵌套结构(如 `{'feature': tensor1, 'label': tensor2}`)。 --- #### **基础用法** ##### 示例 1:单个张量输入 ```python import tensorflow as tf # 创建形状为 (4,) 的一维张量 data = tf.constant([10, 20, 30, 40]) dataset = tf.data.Dataset.from_tensor_slices(data) # 迭代输出 for item in dataset: print(item.numpy()) # 输出: 10, 20, 30, 40(4个标量元素) ``` ##### 示例 2:二维张量输入 ```python # 创建形状为 (3, 2) 的二维张量 matrix = tf.constant([[1, 2], [3, 4], [5, 6]]) dataset = tf.data.Dataset.from_tensor_slices(matrix) for item in dataset: print(item.numpy()) # 输出: [1 2], [3 4], [5 6](3个一维张量) ``` --- #### **高级用法** ##### 示例 3:元组输入(特征+标签) ```python features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) # 形状 (3, 2) labels = tf.constant([0, 1, 0]) # 形状 (3,) # 传入元组 (features, labels) dataset = tf.data.Dataset.from_tensor_slices((features, labels)) for feat, label in dataset: print(f"Feature: {feat.numpy()}, Label: {label.numpy()}") # 输出: # Feature: [0.1 0.2], Label: 0 # Feature: [0.3 0.4], Label: 1 # Feature: [0.5 0.6], Label: 0 ``` ##### 示例 4:字典输入 ```python data_dict = { 'image': tf.ones(shape=(2, 28, 28)), # 228x28图片 'label': tf.constant([3, 7]) } dataset = tf.data.Dataset.from_tensor_slices(data_dict) for sample in dataset: print(sample['image'].shape, sample['label'].numpy()) # 输出: # (28, 28) 3 (第一组切片) # (28, 28) 7 (第二组切片) ``` --- #### **注意事项** 1. **内存限制**: 输入数据会**整体加载到内存**,适用于中小型数据集。大型数据集使用 `tf.data.TFRecordDataset`[^1]。 2. **维度对齐**: 多输入时,所有张量的第 0 维大小必须一致,否则报错 `ValueError`[^2]。 3. **典型组合**: 常与以下方法链式调用: - `.batch(32)`:将数据分批次 - `.repeat(epochs)`:重复数据集 - `.shuffle(buffer_size)`:打乱顺序 ```python dataset = tf.data.Dataset.from_tensor_slices(data) .shuffle(100).batch(32).repeat(2) # 打乱→分批→重复2轮 ``` ---
评论 23
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值