tensorflow 实现矩阵逻辑变换(tf.gather/tf.gather_nd)

本文探讨了PyTorch与TensorFlow中实现逻辑变换的方法。通过具体实例展示了如何利用PyTorch轻松实现逻辑变换,并介绍了在TensorFlow中采用不同思路实现相似功能的技术细节。
部署运行你感兴趣的模型镜像

Pytorch实现逻辑变换很容易,因为与numpy互相转换。但是,tensorflow实现可能有些困难,但是可以转换思路实现。

例如:

 pytorch  实现 lab[lab>10]=255,然后再将255去掉,只取非255部分,计算loss

 tensorflow 只需要计算loss时,tf.gather_nd(lab,index) ,index=tf.where(lab<=10)

使用例子:

import numpy as np
import tensorflow as tf

a=np.array(list(range(100)))
a=a.reshape([10,10])
b=tf.constant(a)

index=tf.where(b>50)
# 提取 多维度的矩阵逻辑操作
c=tf.gather_nd(b,index)
# 提取 一维的矩阵逻辑变换
c=tf.gather(b,index)

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

--------------------------------------------------------------------------- ValueError Traceback (most recent call last) /tmp/ipykernel_553597/306262114.py in <module> 7 ## Training 8 # Epoch_list,Loss_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch) ----> 9 Epoch_list,Loss_list, Min_Distance_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch, min_distance_threshold=0.7,flags_schedule=[(1, 0), (0, 1), (1, 1)],iter_per_stage=50) /tmp/ipykernel_553597/4102420687.py in model_train(batchsize, channel_SNR, noise_init, nl_factor, eq_flag, epsilon, earlystop_epoch, min_distance_threshold, flags_schedule, iter_per_stage) 58 59 (batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S, ---> 60 p_s, norm_constellation, x, min_distance) = train_step( 61 channel_SNR, noise_tf, GS_flag_now, PS_flag_now, eq_flag, epsilon, min_distance_threshold 62 ) ~/miniconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs) 151 except Exception as e: 152 filtered_tb = _process_traceback_frames(e.__traceback__) --> 153 raise e.with_traceback(filtered_tb) from None 154 finally: 155 del filtered_tb /tmp/__autograph_generated_file_jsnzuik.py in tf__train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon, min_distance_threshold) 39 batch_size = ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(p_s),), None, fscope)[0] 40 batch_indices = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).range, (ag__.ld(batch_size),), None, fscope)[:, ag__.ld(tf).newaxis, ag__.ld(tf).newaxis], [1, ag__.ld(M_int), ag__.ld(k)]), None, fscope) ---> 41 gather_indices = ag__.converted_call(ag__.ld(tf).stack, ([ag__.ld(batch_indices), ag__.converted_call(ag__.ld(tf).tile, (ag__.ld(topk_indices)[:, :, ag__.ld(tf).newaxis, :], [1, 1, ag__.ld(k), 1]), None, fscope)],), dict(axis=(- 1)), fscope) 42 neighbor_probs = ag__.converted_call(ag__.ld(tf).gather_nd, (ag__.ld(p_s), ag__.ld(gather_indices)), None, fscope) 43 neighbor_sum = ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(neighbor_probs),), dict(axis=(- 1)), fscope) ValueError: in user code: File "/tmp/ipykernel_553597/675414708.py", line 77, in train_step * gather_indices = tf.stack([ ValueError: Shapes must be equal rank, but are 3 and 4 From merging shape 0 with other shapes. for '{{node stack_1}} = Pack[N=2, T=DT_INT32, axis=-1](Tile, Tile_1)' with input shapes: [1,8,3], [1,8,3,3].
08-22
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值