tf.split()在keras中切割张量

本文深入讲解TensorFlow中split函数的使用方法,包括tf.split的基本语法、在Keras中通过Lambda层切割张量的具体操作,以及如何调整切割后的张量顺序并进行还原。适合于需要理解和运用TensorFlow进行数据处理的开发者。
部署运行你感兴趣的模型镜像

TensorFlow中的split函数

1. tf.split()函数

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

value:传入的tensor,就是传入的矩阵或高维矩阵
num_or_size_splits:切成几份,比如输入的是2,那么会在传入的axis上切成2份。举个栗子:2 x 32 x 32的tensor,在axis = 0上切2份,就是两个1 x 32 x 32,具体顺序如何,一会再讨论。
axis:在哪个轴上进行切割,比如2 x 32 x 32,依次对应axis = 0 , 1 和 2

2. tf.split()在keras中切割张量

Lambda层

keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)

如果split()想要用到keras中,就必须套入Lambda,作为神经网络的一层出现。具体写法如下:

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

Lambda层的第一个参数是要作为层出现的函数,第二个参数形式为字典,指的是要传入前面函数的参数,其中key是函数API中定义的参数名称,value是要传入的参数。
这里需要注意,tensor这个参数作为Lambda的层的输入,写在最后的(input_tensor)里

Lambda层的详细介绍见keras中文文档

切割张量的排列顺序

以 input_tensor.shape = (2, 32, 32) 为例,输入到网络的 shape 应该是
(?, 2, 32, 32)
用如下 split() 进行切割

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

x 的形状应该是

# x.shape = (4,?,2,8,32)

即 split() 会将切割的片段放到axis = 0的位置

我的理解是split()先将 ? x 2 x 32 x 32 切成 ? x 2 x 4 x 8 x 32,再转置成4 x ? x 2 x 8 x32
为什么有如上的理解,和下面要说的还原有关

它们具体是如何排列的,只需要 print 一下 tensor 就可以了!

调整顺序

tf.transpose() 转置函数,第一个参数是tensor,第二个参数是axis的顺序

x = Lambda(tf.transpose, arguments={'perm': [1,2,0,3,4]})(x)
# x.shape = (?,2,4,8,32)

还原

还原成最初的样子以及顺序

# 直接reshape
x = Reshape((2, 32, 32,))(x)

这里由于调整顺序部分已经将tensor的顺序调整为2 x 4 x 8 x 32了,即我理解的split()切割的第一步,因此只需要reshape就可以还原回去了。

注:不知道是不是可以写成 2 x 8 x 4 x 32(也许可以…)

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.9

TensorFlow-v2.9

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值