概述
tf.data.Dataset.from_tensor_slices作用于切分传入Tensor的第一个维度。生成相应的dataset。
用法
1.传入的数据为矩阵,假如它的形状为(6,3) ,tf.data.Dataset.from_tensor_slices会将其切分矩阵的第一维度,最后生成的dataset含有6个元素,每个元素的形状为(3, ),即每个元素是矩阵的一行。
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(6,3)))
for data in enumerate(dataset):
print(data,"****")
print("-------")
OUT:
2.对于元素为字典或者是元组(例如:在图像识别里面的一个元素可以是{“image”:“image_tensor”, “label”:“label_tensor”})
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices({'image':np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), 'label':np.random.uniform(size=(6,3))})
dataset
OUT:
函数会分别切分"image"和"label"的数值,切分后的dataset中的元素的形式类似于{“image”:1.0, “b”:[0.9,0.1]}这样的形式。
enumerate的用法请参考:https://blog.youkuaiyun.com/silent1cat/article/details/119647131
uniform的用法请参考:https://blog.youkuaiyun.com/silent1cat/article/details/119750523
希望这篇文章对大家的学习有所帮助!