人工智能AI编程基础

tensor切片的方法在实践中大量运用,其中涉及到多维度的切片操作,有时还是挺让人头晕的。

tf.gather()的下标取值和切片的方法:
import tensorflow as tf
from datetime import datetime
import numpy as np
 
 
def pprint(*args, **kwargs):
    print(datetime.now(), *args, **kwargs, end='\n' + '*' * 50 + '\n')
 
 
params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
pprint(params[3].numpy())  # 获取第4个
pprint(tf.gather(params, 3).numpy())  # 获取第4个
pprint(tf.gather(params, indices=[2, 0, 2, 5]).numpy())  # 分别获取第3个,第1个,第3个,第6个
pprint(tf.gather(params, [[2, 0], [2, 5]]).numpy())  # 分别取下标的值,然后生成一个2 * 2的数组
 
params = tf.constant([[0, 1.0, 2.0],
                      [1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值