@tf.function
def train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon=1e-12, min_distance_threshold=0.5):
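    """Run one training step of the end-to-end shaping/equalization model.

    Args (interpretation inferred from the code and variable names):
        inp_SNR: network input carrying the operating SNR, fed to logit_model.
        noise: additive channel noise added to the pulse-shaped transmit signal.
        PS_flag, GS_flag, eq_flag: 0/1 switches enabling updates of the
            probabilistic-shaping logits, the geometric-shaping mapper, and
            the equalizer, respectively.
        epsilon: small constant for numerical stability.
        min_distance_threshold: minimum allowed Euclidean distance between
            normalized constellation points before a penalty is applied.
    """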
    loss = 0
    with tf.GradientTape() as tape:
        # Original forward pass
        s_logits = logit_model(inp_SNR)
        # batch_size = tf.shape(inp_SNR)[0]
        # s_logits = tf.zeros((batch_size, M), dtype=tf.float32)
        s = s_model(s_logits)
        soft_bits = soft_bit_encoder(s)
        hard_bits = hard_decision_on_bit(soft_bits)
        enc = Trans_model_bit(hard_bits)
        # Build the full constellation from the bit labels of the symbol set
        bit_set = tf.math.mod(tf.bitwise.right_shift(tf.expand_dims(symbol_set, 1), tf.range(bitlen)), 2)
        bit_set = tf.reverse(bit_set, axis=[-1])
        constellation = Trans_model_bit(bit_set)
        constellation = tf.expand_dims(constellation, 0)
        # Normalization
        p_s = tf.nn.softmax(s_logits)
        magnitudes = tf.abs(constellation)
        max_mag = tf.reduce_max(magnitudes)
        norm_factor = 1.30793 / tf.maximum(max_mag, epsilon)
        norm_constellation = r2c(norm_factor) * constellation
        x = r2c(norm_factor) * enc
        # === Minimum-distance constraint on the constellation points ===
        points = tf.squeeze(tf.stack([tf.math.real(norm_constellation),
                                      tf.math.imag(norm_constellation)], axis=-1))
        diff = tf.expand_dims(points, 1) - tf.expand_dims(points, 0)  # [M, M, 2]
        distances = tf.norm(diff, axis=-1)                            # [M, M]
        mask = tf.eye(tf.shape(distances)[0], dtype=tf.bool)
        valid_distances = tf.where(mask, tf.ones_like(distances) * 1e10, distances)  # mask out self-distances
        min_distance = tf.reduce_min(valid_distances)
        distance_penalty = tf.nn.relu(min_distance_threshold - min_distance) * 50.0
        # === New: invertibility constraint on the probability distribution ===
        # 1. Entropy of the initial uniform distribution (baseline)
        num_constellation_points = tf.cast(tf.shape(constellation)[1], tf.float32)
        # Compute log2 via the change-of-base formula: log2(x) = ln(x) / ln(2)
        uniform_entropy = tf.math.log(num_constellation_points) / tf.math.log(2.0)  # entropy of the uniform distribution
        # 2. Entropy of the current distribution in bits; epsilon guards against log(0)
        current_entropy = -tf.reduce_sum(p_s * tf.math.log(p_s + epsilon) / tf.math.log(2.0))
        # 3. Entropy-constraint penalty
        entropy_ratio = current_entropy / uniform_entropy
        entropy_penalty = tf.nn.relu(0.9 - entropy_ratio) * 200.0
        # 4. Probability-floor constraint
        min_prob = tf.reduce_min(p_s)
        prob_floor_penalty = tf.nn.relu(epsilon - min_prob) * 200.0
        # === Original loss computation ===
        Tx = upsample_pulse_shaping(x, Fs, h_rrc, fa, fc)
        Rx = Tx + noise
        y = Model_Eq(Rx)
        entropy_S = -p_norm(p_s, p_s, lambda x: log2(x))
        GMI = GMIcal_tf(x, tf.squeeze(y), M, norm_constellation, hard_bits_out, p_s)
        NGMI = 1 - (entropy_S - GMI) / bitlen
        loss_NGMI = tf.nn.relu(NGMI_th - NGMI)
        loss_Eq = tf.reduce_mean(tf.square(tf.abs(x - y)))
        # === Modified loss (all penalty terms added) ===
        loss = (loss_Eq * eq_flag * 0.5
                - GMI
                + loss_NGMI * 100
                + distance_penalty
                + entropy_penalty       # new: entropy-constraint penalty
                + prob_floor_penalty)   # new: probability-floor penalty
    # Gradient computation and update
    # variables = []
    # if PS_flag == 1:
    #     variables.extend(logit_model.trainable_variables)
    #     variables.extend(s_model.trainable_variables)
    # if GS_flag == 1:
    #     variables.extend(Trans_model_bit.trainable_variables)
    # if eq_flag == 1:
    #     variables.extend(Model_Eq.trainable_variables)
    # Select trainable variables: multiplying a Python list by a 0/1 flag
    # keeps it (flag = 1) or drops it (flag = 0).
    variables = (logit_model.trainable_variables * PS_flag +
                 s_model.trainable_variables +
                 Trans_model_bit.trainable_variables * GS_flag +
                 Model_Eq.trainable_variables * eq_flag)
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    # Keep the original return-value structure unchanged
    return loss, loss_Eq, NGMI, GMI, tf.reduce_mean(entropy_S), p_s, norm_constellation, x, min_distance
New constraint to add: the sum of the probabilities of a point and its three neighbouring points must not exceed 4/M. The current code runs correctly; when modifying it, change only what is necessary.
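Below is a minimal sketch of one way to express the requested constraint as another hinge penalty, assuming that "neighbouring points" means the three nearest points of the normalized constellation in Euclidean distance. It reuses p_s and valid_distances from the step above; the helper name neighbor_prob_penalty_fn, the batch-averaging of p_s, and the penalty weight of 200.0 (chosen to mirror the other penalty terms) are illustrative assumptions, not part of the original code.

import tensorflow as tf

def neighbor_prob_penalty_fn(p_s, valid_distances, k_neighbors=3, weight=200.0):
    # Hinge penalty if, for any constellation point, its probability plus the
    # probabilities of its k_neighbors nearest points exceeds (k_neighbors + 1) / M.
    M_int = tf.shape(valid_distances)[0]
    M_f = tf.cast(M_int, tf.float32)
    # Collapse a possibly batched p_s ([batch, M]) to a single distribution [M]
    # by averaging over the batch (a no-op for a single distribution).
    p_flat = tf.reduce_mean(tf.reshape(p_s, tf.stack([-1, M_int])), axis=0)
    # Indices of the k nearest neighbours of each point; self-distances are
    # already masked to 1e10 in valid_distances, so a point never picks itself.
    _, nn_idx = tf.math.top_k(-valid_distances, k=k_neighbors)          # [M, k]
    neighbor_prob = tf.reduce_sum(tf.gather(p_flat, nn_idx), axis=-1)   # [M]
    group_prob = p_flat + neighbor_prob                                 # point + its k neighbours
    threshold = (k_neighbors + 1) / M_f                                 # 4 / M for k_neighbors = 3
    return weight * tf.reduce_sum(tf.nn.relu(group_prob - threshold))

Inside the GradientTape block, right after distance_penalty is computed, something like neighbor_penalty = neighbor_prob_penalty_fn(p_s, valid_distances) could be evaluated and "+ neighbor_penalty" appended to the loss sum; nothing else needs to change. Note that the neighbour selection via top_k is non-differentiable, so gradients from this term flow only into the shaping logits through p_s, not into the constellation geometry.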