tf.reverse 函数的使用

本文详细介绍了TensorFlow中的tf.reverse函数,该函数用于沿指定轴翻转张量。文章通过实例展示了如何使用tf.reverse函数,并将其结果与numpy的切片操作进行对比,验证了其正确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.reverse 函数用于将某个张量沿着某条轴倒转。

它的函数声明如下:

reverse(
    tensor,
    axis,
    name=None
)

参数:

tensor:是某个张量,数据类型可以是整数、浮点数、布尔、字符串、复数。秩最高可以为8.

axis:列表,如[-1]会倒转最后一条轴,[2]会倒转第二条轴(从0开始)。

name(可选):字符串,指定操作名。

返回:

一个与tensor有着相同类型和形状的张量。


它的操作与numpy有些切片操作等同。例如:

import tensorflow as tf
import numpy as np

x = np.random.randn(10, 224, 224, 3)
x_reversed1 = x[:, :, :, ::-1]   
x_reversed2 = x[:, ::-1, :, :]
x_reversed3 = x[:, ::-1, ::-1, :]

y1_1 = tf.reverse(x, axis=[-1])  # 可以使用负数指定某一轴
y1_2 = tf.reverse(x, axis=[3])
y2_1 = tf.reverse(x, axis=[1])
y2_2 = tf.reverse(x, axis=[-3])
y3 = tf.reverse(x, axis=[1, 2])  # 可以同时倒转多个轴

sess = tf.Session()

y1_1_v, y1_2_v, y2_1_v, y2_2_v, y3_v = sess.run([y1_1, y1_2, y2_1, y2_2, y3])

np.all(x_reversed1 == y1_1_v) # True
np.all(x_reversed1 == y1_2_v) # True
np.all(x_reversed2 == y2_1_v) # True
np.all(x_reversed2 == y2_2_v) # True
np.all(x_reversed3 == y3_v)   # True

 

@tf.function def train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon=1e-12, min_distance_threshold=0.5): loss = 0 with tf.GradientTape() as tape: # 原始前向传播计算 s_logits = logit_model(inp_SNR) # batch_size = tf.shape(inp_SNR)[0] # s_logits = tf.zeros((batch_size, M), dtype=tf.float32) s = s_model(s_logits) soft_bits = soft_bit_encoder(s) hard_bits = hard_decision_on_bit(soft_bits) enc = Trans_model_bit(hard_bits) # 生成完整星座图 bit_set = tf.math.mod(tf.bitwise.right_shift(tf.expand_dims(symbol_set, 1), tf.range(bitlen)), 2) bit_set = tf.reverse(bit_set, axis=[-1]) constellation = Trans_model_bit(bit_set) constellation = tf.expand_dims(constellation, 0) # 归一化处理 p_s = tf.nn.softmax(s_logits) magnitudes = tf.abs(constellation) max_mag = tf.reduce_max(magnitudes) norm_factor = 1.30793 / tf.maximum(max_mag, epsilon) norm_constellation = r2c(norm_factor) * constellation x = r2c(norm_factor) * enc # === 星座点最小距离约束 === points = tf.squeeze(tf.stack([tf.math.real(norm_constellation), tf.math.imag(norm_constellation)], axis=-1)) diff = tf.expand_dims(points, 1) - tf.expand_dims(points, 0) # [M, M, 2] distances = tf.norm(diff, axis=-1) # [M, M] mask = tf.eye(tf.shape(distances)[0], dtype=tf.bool) valid_distances = tf.where(mask, tf.ones_like(distances)*1e10, distances) min_distance = tf.reduce_min(valid_distances) distance_penalty = tf.nn.relu(min_distance_threshold - min_distance) * 50.0 # === 新增:概率分布可逆性约束 === # 1. 计算初始均匀分布的熵(基准值) num_constellation_points = tf.cast(tf.shape(constellation)[1], tf.float32) # 使用换底公式计算log2: log2(x) = ln(x)/ln(2) uniform_entropy = tf.math.log(num_constellation_points) / tf.math.log(2.0) # 均匀分布的熵 # 2. 计算当前分布的熵 current_entropy = -tf.reduce_sum(p_s * tf.math.log(p_s) / tf.math.log(2.0)) # 以2为底的熵 # 3. 熵约束惩罚 entropy_ratio = current_entropy / uniform_entropy entropy_penalty = tf.nn.relu(0.9 - entropy_ratio) * 200.0 # 4. 概率下限约束 min_prob = tf.reduce_min(p_s) prob_floor_penalty = tf.nn.relu(epsilon - min_prob) * 200.0 # === 原始损失计算 === Tx = upsample_pulse_shaping(x, Fs, h_rrc, fa, fc) Rx = Tx + noise y = Model_Eq(Rx) entropy_S = -p_norm(p_s, p_s, lambda x: log2(x)) GMI = GMIcal_tf(x, tf.squeeze(y), M, norm_constellation, hard_bits_out, p_s) NGMI = 1 - (entropy_S - GMI) / bitlen loss_NGMI = tf.nn.relu(NGMI_th - NGMI) loss_Eq = tf.reduce_mean(tf.square(tf.abs(x - y))) # === 修改后的损失函数(添加所有惩罚项) === loss = (loss_Eq * eq_flag * 0.5 - GMI + loss_NGMI * 100 + distance_penalty + entropy_penalty # 新增:熵约束惩罚 + prob_floor_penalty) # 新增:概率下限惩罚 # # 梯度计算与更新 # variables = [] # if PS_flag == 1: # variables.extend(logit_model.trainable_variables) # variables.extend(s_model.trainable_variables) # if GS_flag == 1: # variables.extend(Trans_model_bit.trainable_variables) # if eq_flag == 1: # variables.extend(Model_Eq.trainable_variables) variables = (logit_model.trainable_variables * PS_flag + s_model.trainable_variables + Trans_model_bit.trainable_variables * GS_flag + Model_Eq.trainable_variables * eq_flag) gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) # 保持原始返回值结构不变 return loss, loss_Eq, NGMI, GMI, tf.reduce_mean(entropy_S), p_s, norm_constellation, x, min_distance 新增约束条件,一个点与其相邻三个点的概率和不能超过4/M。当前代码可以正常运行,修改时只修改必要地方
最新发布
08-22
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值