tf.train.batch,tf.train.batch_join

本文深入解析TensorFlow中tf.train.batch和tf.train.batch_join函数的使用方法及区别,通过实例演示了如何从不同数据集中批量抽取数据,展示了enqueue_many参数的作用,并对比了两种函数在数据处理上的异同。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

用到的测试数据如下:

t1 = tf.constant(np.array([0])) # (1, ) data
t2 = tf.constant(np.array([3]))
t3 = tf.constant(np.array([4]))
t4 = tf.constant(np.array([5]))
t0 = tf.constant(np.array([[0], [3], [4], [5]])) # (4, 1) data
t5 = tf.constant(np.array([1, 2, 7, 8]))  # ( 4, ) label
t0_ = tf.constant(np.array([[00], [33], [44], [55]])) 
t5_ = tf.constant(np.array([11, 22, 77, 88]))
li = [[[t1, t2, t3, t4], t5], [t0_, t5_]]  # for tf.train.batch_join
li2 = [t0_, t5_] # for tf.train.batch

       li 最外面的括号代表一整个tensors_list,[[t1, t2, t3, t4], t5]=[datas,labels],[t1, t2, t3, t4]代表了四个data(分别为0,3,4,5),size都是(1, )。
       [t1, t2, t3, t4] 和 t0_ 这两种写法效果是一样的。t0相当于把[t1, t2, t3, t4]合成了一个tensor,第一维度是数据的数量。
       li2是li中的一个数据集,这就是tf.train.batch和tf.train.batch_join的区别。
       tf.train.batch_join一次从好几个数据集中提取数据,然后合并起来。
       tf.train.batch从一个数据集中提取数据。
下面是测试代码:
tf.train.batch, enqueue_many=True

y1 = tf.train.batch(li2, shapes=[(1, ), ()], batch_size=3, enqueue_many=True)
print(sess.run(y1))
print(sess.run(y1))
print(sess.run(y1))

       enqueue_many为True的意思是[t0_, t5_] 对应了很多个样本,t0_和t5_的第一维度大小对应着样本的数量。改变第一维度可以得到4个data和对应的label。
       shapes对应data和label的size,这里分别是(1,)和( )。
结果:

[array([[ 0],
       [33],
       [44]]), array([11, 22, 77])]
[array([[55],
       [ 0],
       [33]]), array([88, 11, 22])]
[array([[44],
       [55],
       [ 0]]), array([77, 88, 11])]

tf.train.batch, enqueue_many=False

y1 = tf.train.batch(li2, shapes=[(4, 1), (4,)], batch_size=3, enqueue_many=False)
print(sess.run(y1))

       enqueue_many为False的意思是[t0_, t5_] 对应一个样本。所以这里t0就是data,t5就是label。
       shapes对应data和label的size,分别是(4,1)和(4, )。
结果:

[array([[[ 0],
        [33],
        [44],
        [55]],

       [[ 0],
        [33],
        [44],
        [55]],

       [[ 0],
        [33],
        [44],
        [55]]]), array([[11, 22, 77, 88],
       [11, 22, 77, 88],
       [11, 22, 77, 88]])]

tf.train.batch_join, enqueue_many=False

y1 = tf.train.batch_join(li, shapes=[(4, 1), (4,)], batch_size=4, enqueue_many=False)
print(sess.run(y1))
print(sess.run(y1))
print(sess.run(y1))

       enqueue_many和shapes的意思和tf.train.batch中的一样。但是取值可以从两个数据集中取。
结果:

[array([[[ 0],
        [33],
        [44],
        [55]],

       [[ 0],
        [ 3],
        [ 4],
        [ 5]],

       [[ 0],
        [33],
        [44],
        [55]]]), array([[11, 22, 77, 88],
       [ 1,  2,  7,  8],
       [11, 22, 77, 88]])]
[array([[[ 0],
        [ 3],
        [ 4],
        [ 5]],

       [[ 0],
        [33],
        [44],
        [55]],

       [[ 0],
        [ 3],
        [ 4],
        [ 5]]]), array([[ 1,  2,  7,  8],
       [11, 22, 77, 88],
       [ 1,  2,  7,  8]])]
[array([[[ 0],
        [33],
        [44],
        [55]],

       [[ 0],
        [ 3],
        [ 4],
        [ 5]],

       [[ 0],
        [33],
        [44],
        [55]]]), array([[11, 22, 77, 88],
       [ 1,  2,  7,  8],
       [11, 22, 77, 88]])]

tf.train.batch_join, enqueue_many=True

y1 = tf.train.batch_join(li, shapes=[(1, ), ()], batch_size=3, enqueue_many=True)
print(sess.run(y1))
print(sess.run(y1))
print(sess.run(y1))

结果

[array([[0],
       [3],
       [4]]), array([1, 2, 7])]
[array([[5],
       [0],
       [3]]), array([8, 1, 2])]
[array([[4],
       [5],
       [0]]), array([ 7,  8, 11])]

可以看出输出的batch中数据是从两个数据集中随机取的。

全部代码

t1 = tf.constant(np.array([0]))
t2 = tf.constant(np.array([3]))
t3 = tf.constant(np.array([4]))
t4 = tf.constant(np.array([5]))
t0 = tf.constant(np.array([[0], [3], [4], [5]]))
t5 = tf.constant(np.array([1, 2, 7, 8]))
t0_ = tf.constant(np.array([[00], [33], [44], [55]]))  # 4*1
t5_ = tf.constant(np.array([11, 22, 77, 88]))

li = [[[t1, t2, t3, t4], t5], [t0_, t5_]]
li2 = [t0_, t5_]

with tf.Session() as sess:
    y1 = tf.train.batch_join(li, shapes=[(1, ), ()], batch_size=3, enqueue_many=True)
    # y1 = tf.train.batch(li2, shapes=[(4, 1), (4,)], batch_size=3, enqueue_many=False)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    print(sess.run(y1))
    print(sess.run(y1))
    print(sess.run(y1))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值