tf.data.Dataset.from_tensor_slices()
语义解释:from_tensor_slices,从张量的切片读取数据。
工作原理:将输入的张量的第一个维度看做样本的个数,沿其第一个维度将tensor切片,得到的每个切片是一个样本数据。实现了输入张量的自动切片。
输入数据格式/要求:
1)可以是numpy格式,也可以是tensorflow的tensor的格式,函数会自动将numpy格式转为tensorflow的tensor格式
2)输入可以是一个tensor
或 一个tensor字典(字典的每个key对应的value是一个tensor,要求各tensor的第一个维度相等)
或 一个tensor tuple(tuple 的每个元素是一个tensor,要求各tensor的第一个维度相等)。
示例代码:
import tensorflow as tf
import numpy as np
## 测试1: 输入是一个 tensor,函数将样本个数识别为8,然后对张量切片,每个样本的维度是(100)
dataset_tensor = tf.data.Dataset.from_tensor_slices(tf.random_uniform([8,100]))
print("dataset_tensor.output_shapes = ",dataset_tensor.output_shapes)
## 测试2: 输入是一个 numpy
dataset_numpy = tf.data.Dataset.from_tensor_slices(np.random.randn(8,100))
print("dataset_numpy.output_shapes = ",dataset_numpy.output_shapes)
## 测试3: 输入是一个 dict:当不同的value-tensor的第一个维度不同时,会报错,无法对各张量统一切片
dataset_dict = tf.data.Dataset.from_tensor_slices(
{"a":tf.random_uniform([8,100]),
"b":tf.random_uniform([8,1000])})
print("dataset_dict.output_shapes = ",dataset_dict.output_shapes)
## 测试4: 输入是一个 tuple:当不同的 tensor元素的第一个维度不同时,会报错,无法对各张量统一切片
dataset_tuple = tf.data.Dataset.from_tensor_slices(
(tf.ones([8,10]),tf.zeros([8,100]),tf.random_uniform([8,15,100]))
)
print("dataset_tuple.output_shapes = ",dataset_tuple.output_shapes)
本文介绍了tf.data.Dataset.from_tensor_slices()的语义、工作原理、输入数据格式及要求。该函数从张量的切片读取数据,将输入张量沿第一个维度切片,实现自动切片。输入数据可以是numpy或tensorflow的tensor格式,还可以是tensor、tensor字典或tensor tuple。
3723

被折叠的 条评论
为什么被折叠?



