TensorFlow 中张量切分操作 tf.split 使用实例

本文详细介绍TensorFlow中tf.split函数的使用方法,包括参数解释、官方文档链接及代码示例,展示了如何将张量按指定维度和数量进行拆分。
部署运行你感兴趣的模型镜像

一、环境

TensorFlow API r1.14(rc)

CUDA 9.0 V9.0.176

Python 3.6.3

二、官方说明

把张量分解成子张量

https://tensorflow.google.cn/versions/r1.14/api_docs/python/tf/split

tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name=‘split’
)

参数:
value:要分割的张量
num_or_size_splits:可以是整数(指定把张量划分为几分,需要注意的是必须能整除 value.shape[axis])、一维张量或Python 列表(指定划分输出的每一个子张量的大小,需要注意的是张量或列表中的元素之和需要等于 value 要划分的维度的大小)
axis:整数或标量 int32 张量,指定沿那个维度分割张量,数值必须在 [-rank(value), rank(value)] 之间,默认是 0。
num:可选参数。当不能从 num_or_size_splits 的形状推断输出的数量时,通过该参数指定
name:可选参数。操作的名称

三、实例

>>> import tensorflow as tf
>>> import numpy as np
>>> tf.enable_eager_execution()
>>> data = np.random.random((5,10))
>>> data_tensor = tf.constant(data)

# 标量
>>> splited_interger = 5
>>> split0_0, split0_1, split0_2, split0_3, split0_4 = tf.split(data_tensor, splited_interger, -1)

>>> tf.shape(split0_0)
<tf.Tensor: id=9, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_1)
<tf.Tensor: id=11, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_2)
<tf.Tensor: id=13, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_3)
<tf.Tensor: id=15, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_4)
<tf.Tensor: id=17, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>

# Python 列表
>>> splited_list = [2,3,5]
>>> split1_0, split1_1, split1_2 = tf.split(data_tensor, splited_list, -1)
>>> tf.shape(split1_0)

<tf.Tensor: id=24, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split1_1)
<tf.Tensor: id=26, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split1_2)
<tf.Tensor: id=28, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>

# 1 维张量
>>> splited_tensor = tf.constant(splited_list)
>>> split2_0, split2_1, split2_2 = tf.split(data_tensor, splited_tensor, -1)
>>> tf.shape(split2_0)

<tf.Tensor: id=38, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split2_1)
<tf.Tensor: id=40, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split2_2)
<tf.Tensor: id=42, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值