tensor切片的方法在实践中大量运用,其中涉及到多维度的切片操作,有时还是挺让人头晕的。
tf.gather()的下标取值和切片的方法:
import tensorflow as tf
from datetime import datetime
import numpy as np
def pprint(*args, **kwargs):
print(datetime.now(), *args, **kwargs, end='\n' + '*' * 50 + '\n')
params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
pprint(params[3].numpy()) # 获取第4个
pprint(tf.gather(params, 3).numpy()) # 获取第4个
pprint(tf.gather(params, indices=[2, 0, 2, 5]).numpy()) # 分别获取第3个,第1个,第3个,第6个
pprint(tf.gather(params, [[2, 0], [2, 5]]).numpy()) # 分别取下标的值,然后生成一个2 * 2的数组
params = tf.constant([[0, 1.0, 2.0],
[1
订阅专栏 解锁全文

被折叠的 条评论
为什么被折叠?



