tf.where tf.gather_nd用法

本文详细解析了TensorFlow中tf.where函数的使用方法,包括其返回true值位置的坐标、不同阶张量的处理方式及具体示例。此外,还介绍了tf.where的完整用法,包括条件参数、x和y参数的使用,以及tf.gather_nd函数的用法。
部署运行你感兴趣的模型镜像

tf.where

记录一下where的用法:
官方文档

tf.where(input, name=None)`
Returns locations of true values in a boolean tensor.

This operation returns the coordinates of true elements in input. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.

For example:
# 'input' tensor is [[True, False]
#                    [True, False]]
# 'input' has two true values, so output has two coordinates.
# 'input' has rank of 2, so coordinates have two indices.
where(input) ==> [[0, 0],
                  [1, 0]]

# `input` tensor is [[[True, False][True, False]]
#                    [[False, True][False, True]]
#                    [[False, False][False, True]]]
# 'input' has 5 true values, so output has 5 coordinates.
# 'input' has rank of 3, so coordinates have three indices.
where(input) ==> [[0, 0, 0], [0, 1, 0],
                  [1, 0, 1],[1, 1, 1],
                  [2, 1, 1]]

改了一下格式一下子就很好理解了 = =
where返回的是为true位置的下标
第一个例子,是一个2阶(这个阶可以这么理解,如果是n个1y的张量就是2阶,如果是n个2y的张量就是3阶,n个3*y的张量还是3阶,阶是指的块数,还不理解就看第二个例子,直观的看,最后有几个]]]这个括号就是几阶,方便定位的…也不知道对不对(捂脸)),所以[0,0]和[0,1]指的是[0,0]和[0,1]这两个位置上是true,
第二个例子因为有三阶(就是三个块),所以输出是三列。每一个块都是[[True, False][True, False]]是一个2x2的矩阵。因为一共有5个true所以有五行。
[0,0,0] 指的是第0块的[0,0]位置,[0,1,0]指的是第0块的 [1,0]位置,以此类推。这样就搞懂惹
完整用法
(转自 https://blog.youkuaiyun.com/weixin_34318326/article/details/92119620):
格式:
tf.where(condition, x=None, y=None, name=None)

参数:
condition: 一个元素为bool型的tensor。元素内容为false,或true。
x: 一个和condition有相同shape的tensor,如果x是一个高维的tensor,x的第一维size必须和condition一样。
y: 和x有一样shape的tensor

返回:
一个和x,y有同样shape的tensor

功能:
遍历condition Tensor中的元素,如果该元素为true,则output Tensor中对应位置的元素来自x Tensor中对应位置的元素;否则output Tensor中对应位置的元素来自Y tensor中对应位置的元素。

tf.gather_nd

gather_nd=gather n dimension tensor

tf.gather_nd(
    params,
    indices,
    name=None
)

例子:

    params = [['a', 'b'], ['c', 'd']]
    indices = [[0, 0], [1, 1]]
    output = ['a', 'd']

一看就懂系列。

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

--------------------------------------------------------------------------- ValueError Traceback (most recent call last) /tmp/ipykernel_553597/306262114.py in <module> 7 ## Training 8 # Epoch_list,Loss_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch) ----> 9 Epoch_list,Loss_list, Min_Distance_list = model_train(batchsize,channel_SNR_db1,noise_init,nl_factor,eq_flag,norm_epsilon,earlystop_epoch, min_distance_threshold=0.7,flags_schedule=[(1, 0), (0, 1), (1, 1)],iter_per_stage=50) /tmp/ipykernel_553597/4102420687.py in model_train(batchsize, channel_SNR, noise_init, nl_factor, eq_flag, epsilon, earlystop_epoch, min_distance_threshold, flags_schedule, iter_per_stage) 58 59 (batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S, ---> 60 p_s, norm_constellation, x, min_distance) = train_step( 61 channel_SNR, noise_tf, GS_flag_now, PS_flag_now, eq_flag, epsilon, min_distance_threshold 62 ) ~/miniconda3/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs) 151 except Exception as e: 152 filtered_tb = _process_traceback_frames(e.__traceback__) --> 153 raise e.with_traceback(filtered_tb) from None 154 finally: 155 del filtered_tb /tmp/__autograph_generated_file_jsnzuik.py in tf__train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon, min_distance_threshold) 39 batch_size = ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(p_s),), None, fscope)[0] 40 batch_indices = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).range, (ag__.ld(batch_size),), None, fscope)[:, ag__.ld(tf).newaxis, ag__.ld(tf).newaxis], [1, ag__.ld(M_int), ag__.ld(k)]), None, fscope) ---> 41 gather_indices = ag__.converted_call(ag__.ld(tf).stack, ([ag__.ld(batch_indices), ag__.converted_call(ag__.ld(tf).tile, (ag__.ld(topk_indices)[:, :, ag__.ld(tf).newaxis, :], [1, 1, ag__.ld(k), 1]), None, fscope)],), dict(axis=(- 1)), fscope) 42 neighbor_probs = ag__.converted_call(ag__.ld(tf).gather_nd, (ag__.ld(p_s), ag__.ld(gather_indices)), None, fscope) 43 neighbor_sum = ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(neighbor_probs),), dict(axis=(- 1)), fscope) ValueError: in user code: File "/tmp/ipykernel_553597/675414708.py", line 77, in train_step * gather_indices = tf.stack([ ValueError: Shapes must be equal rank, but are 3 and 4 From merging shape 0 with other shapes. for '{{node stack_1}} = Pack[N=2, T=DT_INT32, axis=-1](Tile, Tile_1)' with input shapes: [1,8,3], [1,8,3,3].
08-22
pooled = [] box_to_level = [] for i, level in enumerate(range(2, 6)): ix = tf.where(tf.equal(roi_level, level)) level_boxes = tf.gather_nd(boxes, ix) # Box indicies for crop_and_resize. box_indices = tf.cast(ix[:, 0], tf.int32) # Keep track of which box is mapped to which level box_to_level.append(ix) # Stop gradient propogation to ROI proposals level_boxes = tf.stop_gradient(level_boxes) box_indices = tf.stop_gradient(box_indices) # Crop and Resize # From Mask R-CNN paper: "We sample four regular locations, so # that we can evaluate either max or average pooling. In fact, # interpolating only a single value at each bin center (without # pooling) is nearly as effective." # # Here we use the simplified approach of a single value per bin, # which is how it's done in tf.crop_and_resize() # Result: [batch * num_boxes, pool_height, pool_width, channels] pooled.append(tf.image.crop_and_resize( feature_maps[i], level_boxes, box_indices, self.pool_shape, method="bilinear")) # Pack pooled features into one tensor pooled = tf.concat(pooled, axis=0) # Pack box_to_level mapping into one array and add another # column representing the order of pooled boxes box_to_level = tf.concat(box_to_level, axis=0) box_range = tf.expand_dims(tf.range(tf.shape(box_to_level)[0]), 1) box_to_level = tf.concat([tf.cast(box_to_level, tf.int32), box_range], axis=1) # Rearrange pooled features to match the order of the original boxes # Sort box_to_level by batch then box index # TF doesn't have a way to sort by two columns, so merge them and sort. sorting_tensor = box_to_level[:, 0] * 100000 + box_to_level[:, 1] ix = tf.nn.top_k(sorting_tensor, k=tf.shape( box_to_level)[0]).indices[::-1] ix = tf.gather(box_to_level[:, 2], ix) pooled = tf.gather(pooled, ix) # Re-add the batch dimension pooled = tf.expand_dims(pooled, 0) return pooled ,详细解释这段代码
最新发布
11-21
@tf.function def train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon=1e-12, min_distance_threshold=0.5): loss = 0 with tf.GradientTape() as tape: # 原始前向传播计算 s_logits = logit_model(inp_SNR) # batch_size = tf.shape(inp_SNR)[0] # s_logits = tf.zeros((batch_size, M), dtype=tf.float32) s = s_model(s_logits) soft_bits = soft_bit_encoder(s) hard_bits = hard_decision_on_bit(soft_bits) enc = Trans_model_bit(hard_bits) # 生成完整星座图 bit_set = tf.math.mod(tf.bitwise.right_shift(tf.expand_dims(symbol_set, 1), tf.range(bitlen)), 2) bit_set = tf.reverse(bit_set, axis=[-1]) constellation = Trans_model_bit(bit_set) constellation = tf.expand_dims(constellation, 0) # 归一化处理 p_s = tf.nn.softmax(s_logits) magnitudes = tf.abs(constellation) max_mag = tf.reduce_max(magnitudes) norm_factor = 1.30793 / tf.maximum(max_mag, epsilon) norm_constellation = r2c(norm_factor) * constellation x = r2c(norm_factor) * enc # === 星座点最小距离约束 === points = tf.squeeze(tf.stack([tf.math.real(norm_constellation), tf.math.imag(norm_constellation)], axis=-1)) diff = tf.expand_dims(points, 1) - tf.expand_dims(points, 0) # [M, M, 2] distances = tf.norm(diff, axis=-1) # [M, M] mask = tf.eye(tf.shape(distances)[0], dtype=tf.bool) valid_distances = tf.where(mask, tf.ones_like(distances)*1e10, distances) min_distance = tf.reduce_min(valid_distances) distance_penalty = tf.nn.relu(min_distance_threshold - min_distance) * 50.0 # === 新增:概率分布可逆性约束 === # 1. 计算初始均匀分布的熵(基准值) num_constellation_points = tf.cast(tf.shape(constellation)[1], tf.float32) # 使用换底公式计算log2: log2(x) = ln(x)/ln(2) uniform_entropy = tf.math.log(num_constellation_points) / tf.math.log(2.0) # 均匀分布的熵 # 2. 计算当前分布的熵 current_entropy = -tf.reduce_sum(p_s * tf.math.log(p_s) / tf.math.log(2.0)) # 以2为底的熵 # 3. 熵约束惩罚 entropy_ratio = current_entropy / uniform_entropy entropy_penalty = tf.nn.relu(0.9 - entropy_ratio) * 200.0 # 4. 概率下限约束 min_prob = tf.reduce_min(p_s) prob_floor_penalty = tf.nn.relu(epsilon - min_prob) * 200.0 # === 原始损失计算 === Tx = upsample_pulse_shaping(x, Fs, h_rrc, fa, fc) Rx = Tx + noise y = Model_Eq(Rx) entropy_S = -p_norm(p_s, p_s, lambda x: log2(x)) GMI = GMIcal_tf(x, tf.squeeze(y), M, norm_constellation, hard_bits_out, p_s) NGMI = 1 - (entropy_S - GMI) / bitlen loss_NGMI = tf.nn.relu(NGMI_th - NGMI) loss_Eq = tf.reduce_mean(tf.square(tf.abs(x - y))) # === 修改后的损失函数(添加所有惩罚项) === loss = (loss_Eq * eq_flag * 0.5 - GMI + loss_NGMI * 100 + distance_penalty + entropy_penalty # 新增:熵约束惩罚 + prob_floor_penalty) # 新增:概率下限惩罚 # # 梯度计算与更新 # variables = [] # if PS_flag == 1: # variables.extend(logit_model.trainable_variables) # variables.extend(s_model.trainable_variables) # if GS_flag == 1: # variables.extend(Trans_model_bit.trainable_variables) # if eq_flag == 1: # variables.extend(Model_Eq.trainable_variables) variables = (logit_model.trainable_variables * PS_flag + s_model.trainable_variables + Trans_model_bit.trainable_variables * GS_flag + Model_Eq.trainable_variables * eq_flag) gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) # 保持原始返回值结构不变 return loss, loss_Eq, NGMI, GMI, tf.reduce_mean(entropy_S), p_s, norm_constellation, x, min_distance 新增约束条件,一个点与其相邻三个点的概率和不能超过4/M。当前代码可以正常运行,修改时只修改必要地方
08-22
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值