https://blog.youkuaiyun.com/weixin_34613450/article/details/82725036
https://blog.youkuaiyun.com/kdongyi/article/details/82910632
split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
输入:
value:输入的tensor
num_or_size_splits:如果是个整数n,就将输入的tensor分为n个子tensor。如果是个tensor T,就将输入的tensor分为len(T)个子tensor
axis:默认为0,指明在哪个维度上分
举例:
import tensorflow as tf
import numpy as np
value = np.array([[1,2,3],
[4,5,6]])
#axis = 0,表示按行切分,切分的子tensor列数不变
#将value按行切分成两份,第一份0行3列,第2份2行3列
split1,split2 = tf.split(value,[0,2],0)
#axis = 1,表示按列切分,切分的子tensor行数不变
#将value按列切分乘三份,第一份2行1列,第二份2行0列,第三份2行2列
split3,split4,split5 = tf.split(value, [1, 0, 2], 1)
with tf.Session() as sess:
print("split1:\n")
print(sess.run(split1))
print("split2:\n")
print(sess.run(split2))
print("split3:\n")
print(sess.run(split3))
print("split4:\n")
print(sess.run(split4))
print("split5:\n")
print(sess.run(split5))
结果如下:
split1:
[]
split2:
[[1 2 3]
[4 5 6]]
split3:
[[1]
[4]]
split4:
[]
split5:
[[2 3]
[5 6]]