Tensorflow2 的常用函数11-20
11. tf.data.Dataset.from_tensor_slices()数据集切片
该函数是dataset核心函数之一,它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的。
import tensorflow as tf
features = tf.constant([12, 23, 10, 17])
labels = tf.constant([0, 1, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
for element in dataset:
print(element)
结果:
(<tf.Tensor: shape=(), dtype=int32, numpy=12>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)
(<tf.Tensor: shape=(), dtype=int32, numpy=23>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=10>, <tf.Tensor: shape=(), dtype=int32, numpy=1>)
(<tf.Tensor: shape=(), dtype=int32, numpy=17>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)
也就是说,12-0 23-1 10-1 17-0 分别对应。在制作和使用数据集的时候,常常用此函数将特征标签进行组合。
12. tf.GradientTape() 求导
梯度下降求导的核心函数。 这里
import tensorflow as tf

最低0.47元/天 解锁文章
956





