<think>我们希望在损失函数中增加一个新的约束:对于每个星座点,与其相邻的三个星座点(按距离计算)的概率之和不能超过4/M。我们将这个约束称为"局部概率约束"。
步骤:
1. 计算归一化后的星座点(norm_constellation)的坐标,形状为[1, M, 1](复数),我们将其转换为实数坐标,形状为[M,2](每个点用实部和虚部表示)。
2. 计算所有星座点之间的欧氏距离矩阵(M x M)。
3. 对于每个星座点,找到距离最近的三个星座点(不包括自身)。注意:距离矩阵的对角线是0(自身),所以我们需要排除自身。
4. 对于每个点i,获取其相邻的三个点(最近邻)的索引。
5. 计算每个点i的相邻三个点的概率和:sum_j( p_s[j] ),其中j是相邻三个点的索引。
6. 我们希望这个和不超过4/M(4是相邻点的个数?这里我们要求相邻三个点,所以应该是3个点?但是约束要求是4/M?注意:题目要求是“相邻三个点”,所以应该是3个点,但约束上限是4/M?可能4/M是一个预设的阈值,其中M是星座点数。我们按照题目要求:不能超过4/M。
7. 对每个点计算违反约束的程度:如果三个相邻点的概率和大于4/M,则惩罚为超过的部分;否则为0。
8. 将每个点的惩罚求和或平均,然后乘以一个惩罚系数(例如50.0)加到损失函数中。
注意:由于我们使用TensorFlow,需要确保所有操作是张量操作。
具体实现:
假设:
- p_s: 形状为[batch_size, M]的概率分布。注意:batch_size可能大于1,但星座图是共享的,所以每个样本的概率分布不同。但是,我们新增的约束是针对每个样本的概率分布的,所以我们需要对每个样本计算这个约束。
然而,在代码中,星座图是固定的(由Trans_model_bit生成),而p_s是每个输入样本对应的概率分布(由logit_model产生)。因此,我们需要对batch中的每个样本计算这个约束。
但是,注意:在损失函数中,我们计算距离矩阵时,星座图是固定的(与样本无关),所以每个样本看到的星座图是一样的。但是,每个样本有自己的概率分布p_s(形状为[batch_size, M])。因此,我们需要对batch中的每个样本计算其自己的局部概率约束。
步骤:
1. 计算星座点之间的距离矩阵(M x M),这个在batch内是固定的,所以只需要计算一次(在循环外计算?但是我们在函数内,而且星座图是归一化后的,所以每次前向传播都会计算一次,但同一个batch内不同样本的星座图是一样的,因为归一化因子是每个样本独立计算的?注意:归一化因子是每个样本计算的,所以每个样本的星座图归一化后可能不同?但是,我们归一化因子是每个样本计算自己的最大幅度,然后归一化,所以每个样本的星座点位置可能不同?)
然而,在代码中,归一化因子norm_factor是每个样本独立计算的,所以每个样本的星座图位置(norm_constellation)是不同的。因此,每个样本的星座点位置不同,那么每个样本的相邻点也会不同。因此,我们需要对每个样本计算距离矩阵,然后找到每个点的最近邻三个点。
但是,这样计算量很大(每个样本都要计算一个MxM的距离矩阵,然后排序找最近邻)。如果M很大(例如256),batch_size也很大,那么计算量会很大。我们需要权衡。
另一种思路:我们是否可以将星座图固定为归一化后的标准星座图(即不考虑每个样本的归一化差异)?但是,归一化因子是每个样本独立计算的,所以每个样本的星座图位置确实不同(因为归一化因子不同)。因此,我们必须每个样本单独计算。
因此,我们将在每个样本上计算以下操作:
步骤:
a. 获取当前样本的星座点:norm_constellation的形状为[1, M, 1](复数),我们将其转换为实数坐标:使用tf.squeeze去除多余的维度,然后使用tf.stack([tf.math.real(...), tf.math.imag(...)], axis=-1)得到形状为[M,2]的实数坐标。
b. 计算距离矩阵:使用tf.norm计算每两个点之间的欧氏距离,得到一个MxM的矩阵。
c. 将距离矩阵的对角线元素设置为一个很大的数(以避免选择自身作为最近邻)。
d. 使用tf.nn.top_k或tf.math.top_k来找到每个点最近的k个点(k=3)的距离和索引。注意,我们只需要最小的3个(不包括自身)。
e. 对于当前样本的概率分布p_s_i(形状为[M]),我们为每个点i,取其三个最近邻的索引,然后从p_s_i中取出这三个邻居的概率值,求和得到sum_neighbor_prob。
f. 计算每个点的约束违反:violation_i = tf.maximum(0, sum_neighbor_prob - 4.0/M)
g. 整个样本的约束惩罚为所有点违反程度之和(或平均)乘以一个系数。
由于有batch,我们需要对每个样本计算这个惩罚,然后取平均(或求和)作为总的惩罚项。
注意:在代码中,p_s的形状为[batch_size, M],所以我们需要循环处理每个样本?或者使用向量化操作?
向量化方法:
我们可以使用tf.vectorized_map?或者使用tf.map_fn?但这样可能效率不高。我们可以尝试用纯张量操作。
另一种思路:
1. 计算距离矩阵:形状为[batch_size, M, M](因为每个样本有自己的星座图)
2. 对于每个样本,我们都要计算一个距离矩阵,然后找到每个点的最近3个邻居(排除自身)。
但是,计算这样的距离矩阵需要大量的内存(O(batch_size * M * M)),所以如果batch_size和M很大,可能不现实。
权衡:由于我们已经在使用自定义训练循环,而且这个约束是新增的,我们可以考虑在batch_size较小的情况下运行。或者,我们只对星座图进行全局归一化(即使用一个固定的归一化因子,而不是每个样本独立归一化),这样星座图在batch内就是固定的,那么距离矩阵只需要计算一次(整个batch共享)。但是,这样会改变原来的归一化方式,可能会影响性能。
考虑到原代码中每个样本独立归一化是为了保证每个样本的发射信号功率归一化,这是必要的。因此,我们只能接受每个样本独立计算距离矩阵。
由于内存和计算量的考虑,我们可能需要对batch_size进行限制,或者使用循环逐个样本处理(在batch_size维度上循环)。
在TensorFlow中,我们可以使用tf.map_fn来对batch进行循环处理,但这样可能会比较慢。
我们决定:在batch_size较小的情况下(例如32或64),我们直接使用向量化计算整个batch的距离矩阵。如果M=16(较小),那么batch_size=32时,距离矩阵大小为[32,16,16],内存占用为32*16*16*4(float32)= 32KB,这是可以接受的。但如果M=256,那么32*256*256*4=8MB,也是可以接受的(在GPU上)。因此,我们采用向量化计算整个batch。
步骤:
1. 获取norm_constellation:当前norm_constellation是通过归一化每个样本的星座图得到的,形状为[batch_size, M, 1](复数)?注意:在代码中,norm_constellation = r2c(norm_factor) * constellation,其中constellation是[1, M, 1](因为bit_set是固定的,所以星座图初始是固定的,然后乘以每个样本的归一化因子(形状为[batch_size,1,1])?)
实际上,在代码中:
norm_factor 是每个样本的标量(形状为[batch_size]?),然后通过r2c(norm_factor)将其转换为复数(形状为[batch_size,1,1]),然后乘以constellation(形状为[1, M, 1]),通过广播,得到norm_constellation的形状为[batch_size, M, 1]。
所以,norm_constellation的形状为[batch_size, M, 1](复数)。
2. 将norm_constellation转换为实数坐标:使用tf.complex实部和虚部分开,然后合并成一个形状为[batch_size, M, 2]的张量。
points = tf.stack([tf.math.real(norm_constellation), tf.math.imag(norm_constellation)], axis=-1)
然后去掉最后一个维度(因为norm_constellation是三维,最后维是1,所以stack后变成[batch_size, M, 1, 2]?),所以我们需要先squeeze掉最后一个维度(大小为1)?
实际上,norm_constellation的形状是[batch_size, M, 1](复数),所以:
real_part = tf.math.real(norm_constellation) # [batch_size, M, 1]
imag_part = tf.math.imag(norm_constellation) # [batch_size, M, 1]
然后合并:points = tf.concat([real_part, imag_part], axis=-1) # [batch_size, M, 2] (因为最后一维变成了2)
3. 计算距离矩阵:
diff = tf.expand_dims(points, 2) - tf.expand_dims(points, 1) # [batch_size, M, 1, 2] - [batch_size, 1, M, 2] -> [batch_size, M, M, 2]
distances = tf.norm(diff, axis=-1) # [batch_size, M, M]
4. 将距离矩阵的对角线(即自身到自身)设置为一个很大的数(例如1e10),这样在找最近邻时不会选到自己。
mask = tf.eye(M, dtype=tf.bool) # [M, M]
mask = tf.expand_dims(mask, 0) # [1, M, M]
valid_distances = tf.where(mask, tf.ones_like(distances)*1e10, distances)
5. 找到每个点的最近3个邻居(k=3):
topk_values, topk_indices = tf.math.top_k(-valid_distances, k=3+1)
# 注意:我们取负距离然后取topk(相当于取最小距离),k=3+1是因为我们可能包括自身?但是我们设置了mask,自身距离被设为1e10,所以不会选到自身。但是,为了保险,我们取k=3即可?因为已经排除了自身,所以取3个。
# 但是,我们设置mask排除了自身,所以取3个就是最近的3个邻居。
# 但是,top_k返回的是沿着最后一个维度的topk,所以我们需要指定axis=-1。
# 但是,valid_distances的形状是[batch_size, M, M],我们在最后一个维度(索引为2)上取topk,即对每个点(每个样本的每个点)在M个距离中取最小的3个(除了自身)。
# 因此,我们调用:
topk_values, topk_indices = tf.math.top_k(-valid_distances, k=3) # 取负距离,然后取topk(最大的k个负距离,即最小的k个距离)
# 得到的topk_values的形状为[batch_size, M, 3],topk_indices的形状为[batch_size, M, 3]
6. 对于每个样本的每个点i,我们有3个邻居的索引(topk_indices[i, j, :]),然后我们取出这3个邻居在概率分布p_s中的概率值。
p_s的形状为[batch_size, M](每个样本有M个点的概率)
我们需要为每个样本的每个点i,取出其3个邻居的概率值。我们可以使用tf.gather,但需要按样本和点来索引。
使用tf.gather_nd?或者使用tf.batch_gather?我们可以这样:
首先,将p_s扩展一个维度:p_s_expanded = tf.expand_dims(p_s, axis=-1) # [batch_size, M, 1]
然后,我们可以使用tf.gather,但需要将索引调整成合适的形式。
另一种方法:使用tf.one_hot将索引转换为掩码,然后求和。但这样可能复杂。
我们使用tf.gather_along_axis:
neighbor_probs = tf.gather(p_s_expanded, topk_indices, axis=1, batch_dims=1)
# 注意:p_s_expanded的形状为[batch_size, M, 1],topk_indices的形状为[batch_size, M, 3],我们希望在axis=1(即M所在的维度)上收集,并且batch_dims=1(表示第一个维度是batch,并且索引的batch维度与p_s_expanded的batch维度对齐)。
但是,tf.gather_along_axis的用法:
neighbor_probs = tf.gather(p_s_expanded, topk_indices, axis=1, batch_dims=1)
# 这样,我们收集到的形状为[batch_size, M, 3, 1]?然后我们squeeze掉最后一个维度:tf.squeeze(neighbor_probs, axis=-1) -> [batch_size, M, 3]
然后,对邻居概率求和:sum_neighbor_probs = tf.reduce_sum(neighbor_probs, axis=-1) # [batch_size, M]
7. 计算违反约束的程度:我们希望每个点的邻居概率和不超过4/M(注意:这里4/M,M是星座点数)。所以:
threshold = 4.0 / M # 标量
violations = tf.nn.relu(sum_neighbor_probs - threshold) # [batch_size, M]
然后,整个batch的惩罚项为:对每个样本的每个点的违反程度求和(或平均),然后乘以一个惩罚系数(例如50.0)。
local_prob_penalty = tf.reduce_mean(violations) * 50.0 # 这里使用平均,乘以50.0
8. 将local_prob_penalty加到损失函数中。
注意:M是星座点数,我们可以用M = tf.cast(tf.shape(constellation)[1], tf.float32)来获取。
但是,注意:在计算距离矩阵时,我们使用了M(整数),而threshold中的M是浮点数。
修改代码:
在计算距离矩阵之前,我们需要获取M(整数形式):
M_int = tf.shape(p_s)[1] # 因为p_s的形状为[batch_size, M]
M_float = tf.cast(M_int, tf.float32)
然后,threshold = 4.0 / M_float
另外,在设置mask时,我们使用M_int来创建单位矩阵。
考虑到计算效率,我们可能需要在GPU上运行,所以向量化操作是可行的。
但是,如果M很大(比如256),计算整个batch的距离矩阵([batch_size,256,256])需要的内存是batch_size*256*256*4(字节),对于batch_size=64,需要64*256*256*4=16MB,这可以接受。
因此,我们按照上述步骤实现。
代码修改位置:在计算完norm_constellation之后,在计算距离惩罚(distance_penalty)的后面,我们添加这个新的约束。
注意:由于我们已经在计算距离矩阵(用于最小距离约束),我们可以复用这个距离矩阵吗?但是,最小距离约束计算的是最小距离(所有点对的最小值),而这里我们需要每个点的最近邻。而且,最小距离约束中,我们计算的距离矩阵是每个样本独立的(因为归一化因子不同),所以这里我们也要计算每个样本的距离矩阵。但是,我们之前已经计算过一次距离矩阵(用于最小距离约束),但那个计算是针对星座图归一化后的点(即norm_constellation)的,而且我们计算的是实数坐标。所以,我们可以复用那个距离矩阵吗?
在原始代码中,最小距离约束部分:
points = tf.squeeze(tf.stack([tf.math.real(norm_constellation),
tf.math.imag(norm_constellation)], axis=-1))
# 注意:norm_constellation的形状是[batch_size, M, 1](复数),所以stack后形状为[batch_size, M, 1, 2]?然后squeeze掉第2维(大小为1)?所以变成[batch_size, M, 2]
# 然后计算diff:扩展维度得到[batch_size, M, 1, 2]和[batch_size, 1, M, 2],相减得到[batch_size, M, M, 2],然后求范数得到[batch_size, M, M]
但是,在原始代码中,最小距离约束部分计算的是每个样本的星座图的距离矩阵,然后取整个距离矩阵的最小值(排除对角线)。但是,我们这里需要每个点的最近邻三个点,所以需要完整的距离矩阵。
因此,我们可以复用原始代码中已经计算的距离矩阵吗?原始代码中,距离矩阵被计算为distances(形状为[batch_size, M, M]),然后用于计算最小距离。但是,在原始代码中,这个距离矩阵并没有保存下来,因为后面我们直接用了valid_distances来求最小值。所以,如果我们想复用,需要将这部分代码调整一下。
但是,为了代码清晰,我们选择重新计算一次(因为计算距离矩阵的代码并不复杂,而且我们需要的点坐标已经有了)。而且,我们新增的约束需要每个点的邻居,而原始的最小距离约束只需要全局最小距离,所以我们可以独立计算。
因此,我们将在原始代码中计算最小距离约束的部分之后,添加新的约束计算。
但是,注意:在原始代码中,计算最小距离约束时,已经将星座点转换成了实数坐标(points),并且计算了距离矩阵(distances)。我们可以保存这个距离矩阵,避免重复计算。
修改原始代码:将最小距离约束部分计算的距离矩阵保存下来,然后用于新的约束。
步骤:
1. 在原始代码中,最小距离约束部分已经计算了:
points = ... # [batch_size, M, 2]
diff = ... # [batch_size, M, M, 2]
distances = ... # [batch_size, M, M]
2. 然后,我们在这个部分之后,使用这个distances来计算新的约束。
这样,我们避免了重复计算距离矩阵。
因此,我们修改原始代码:
首先,将最小距离约束部分计算的距离矩阵保存下来(不要覆盖),然后我们使用这个距离矩阵来计算最近邻。
但是,注意:最小距离约束部分在计算valid_distances时,将自身距离设为了1e10,然后求全局最小值。这个操作不影响我们使用这个距离矩阵(因为我们也需要排除自身)。
所以,我们可以复用这个distances(在设置mask之前?)还是之后?实际上,我们计算最近邻时也需要排除自身,所以应该使用valid_distances(即已经将自身设为大数的矩阵)。因此,我们保存valid_distances。
修改原始代码:
在最小距离约束部分,我们计算了:
distances = tf.norm(diff, axis=-1) # [batch_size, M, M]
mask = tf.eye(M, dtype=tf.bool) # 注意:这里M是固定的,但M应该用当前样本的星座点数?我们用M_int
# 修改:使用M_int来创建单位矩阵
mask = tf.eye(M_int, dtype=tf.bool) # [M_int, M_int]
mask = tf.expand_dims(mask, 0) # [1, M_int, M_int]
valid_distances = tf.where(mask, tf.ones_like(distances)*1e10, distances)
min_distance = tf.reduce_min(valid_distances) # 标量(整个batch的最小值?注意:reduce_min会reduce所有维度,所以得到整个batch中所有点对的最小距离)
然后,我们保存valid_distances(因为后面新约束要用)。
所以,我们在最小距离约束部分之后,添加新约束的计算。
代码结构:
# === 星座点最小距离约束 ===
... # 计算points, diff, distances, mask, valid_distances, min_distance, distance_penalty
# === 新增:局部概率约束 ===
# 使用valid_distances(形状为[batch_size, M, M])和p_s(形状为[batch_size, M])
M_float = tf.cast(M_int, tf.float32)
threshold = 4.0 / M_float
# 找到每个点的最近3个邻居的索引
# 注意:valid_distances中自身已经被设为大数,所以取最小的3个就是最近邻
topk_values, topk_indices = tf.math.top_k(-valid_distances, k=3) # 取负距离,然后取topk(最大的k个负距离,即最小的k个距离)-> [batch_size, M, 3]
# 注意:topk_indices是索引,表示在最后一个维度(星座点索引)的位置。
# 收集邻居概率
p_s_expanded = tf.expand_dims(p_s, axis=-1) # [batch_size, M, 1]
# 使用gather_nd? 或者使用gather_along_axis?
# 方法:使用tf.gather,但需要将索引和p_s_expanded对齐
neighbor_probs = tf.gather(p_s_expanded, topk_indices, axis=1, batch_dims=1) # [batch_size, M, 3, 1]
neighbor_probs = tf.squeeze(neighbor_probs, axis=-1) # [batch_size, M, 3]
sum_neighbor_probs = tf.reduce_sum(neighbor_probs, axis=-1) # [batch_size, M]
# 计算违反约束:超过threshold的部分
violations = tf.nn.relu(sum_neighbor_probs - threshold) # [batch_size, M]
# 整个batch的惩罚:取所有违反约束的平均值,然后乘以系数
local_prob_penalty = tf.reduce_mean(violations) * 50.0 # 标量
然后,将local_prob_penalty加到损失函数中。
最后,修改损失函数:
loss = (loss_Eq * eq_flag * 0.5
- GMI
+ loss_NGMI * 100
+ distance_penalty
+ entropy_penalty
+ prob_floor_penalty
+ local_prob_penalty) # 新增局部概率约束惩罚
注意:我们乘以50.0是惩罚系数,可以根据效果调整。
另外,在最小距离约束部分,我们原来使用了一个固定的M(比如16)来创建单位矩阵?但是,M应该是动态获取的(即星座点数)。在原始代码中,我们使用了:
mask = tf.eye(tf.shape(distances)[0], dtype=tf.bool)
但是,tf.shape(distances)[0]是batch_size?不对,distances的形状是[batch_size, M, M],所以tf.shape(distances)[0]是batch_size,而我们需要的是M。所以,原始代码有误?
在原始代码中,计算最小距离约束部分:
points = tf.squeeze(tf.stack([tf.math.real(norm_constellation),
tf.math.imag(norm_constellation)], axis=-1))
diff = tf.expand_dims(points, 1) - tf.expand_dims(points, 0) # [batch_size, M, 1, 2] - [batch_size, 1, M, 2] -> [batch_size, M, M, 2]
distances = tf.norm(diff, axis=-1) # [batch_size, M, M]
mask = tf.eye(tf.shape(distances)[0], dtype=tf.bool) # 这里tf.shape(distances)[0]是batch_size,所以创建的是[batch_size, batch_size]的矩阵,但我们需要的是[M, M]的矩阵。所以这里错了!
原始代码中,mask的创建有误。我们应该使用M(星座点数)来创建单位矩阵,而不是batch_size。因此,我们需要修改:
mask = tf.eye(tf.shape(distances)[1], dtype=tf.bool) # 因为distances的形状为[batch_size, M, M],所以索引1是M
mask = tf.expand_dims(mask, 0) # [1, M, M]
然后,valid_distances = tf.where(mask, ...) 会广播到整个batch。
因此,我们修正原始代码中的mask创建。
总结修改:
1. 修正原始最小距离约束的mask创建(使用distances的第1维,即M)。
2. 在最小距离约束部分之后,添加局部概率约束的计算。
3. 将局部概率约束的惩罚项(local_prob_penalty)加到损失函数中。
注意:由于我们复用了距离矩阵,所以不需要重新计算。
代码修改如下(只修改必要部分):</think>为了添加新的约束条件(每个星座点与其相邻的三个点的概率和不能超过4/M),我们需要在现有代码中增加以下步骤:
1. 在计算距离矩阵后,找出每个点的最近邻三个点
2. 计算每个点的邻居概率和
3. 计算违反约束的惩罚项
4. 将惩罚项加入损失函数
以下是修改后的代码(只修改必要部分):
```python
@tf.function
def train_step(inp_SNR, noise, GS_flag, PS_flag, eq_flag, epsilon=1e-12, min_distance_threshold=0.5):
loss = 0
with tf.GradientTape() as tape:
# ... [前面的代码保持不变] ...
# === 星座点最小距离约束 ===
points = tf.squeeze(tf.stack([tf.math.real(norm_constellation),
tf.math.imag(norm_constellation)], axis=-1))
diff = tf.expand_dims(points, 1) - tf.expand_dims(points, 0) # [batch_size, M, M, 2]
distances = tf.norm(diff, axis=-1) # [batch_size, M, M]
# 修正mask创建:使用星座点数M而不是batch_size
M_int = tf.shape(distances)[1] # 获取星座点数(整数)
mask = tf.eye(M_int, dtype=tf.bool) # [M_int, M_int]
mask = tf.expand_dims(mask, 0) # [1, M_int, M_int] 用于广播
valid_distances = tf.where(mask, tf.ones_like(distances)*1e10, distances)
min_distance = tf.reduce_min(valid_distances)
distance_penalty = tf.nn.relu(min_distance_threshold - min_distance) * 50.0
# === 新增:局部概率约束(每个点与其最近3个点的概率和≤4/M) ===
M_float = tf.cast(M_int, tf.float32) # 星座点数(浮点数)
k = 3 # 最近邻数量
# 1. 找到每个点的最近k个邻居(排除自身)
_, topk_indices = tf.math.top_k(-valid_distances, k=k) # 取负距离找最小距离
# 2. 收集邻居概率值
batch_size = tf.shape(p_s)[0]
batch_indices = tf.tile(tf.range(batch_size)[:, tf.newaxis, tf.newaxis],
[1, M_int, k]) # [batch_size, M, k]
gather_indices = tf.stack([
batch_indices,
tf.tile(topk_indices[:, :, tf.newaxis, :], [1, 1, k, 1]) # 调整维度对齐
], axis=-1) # [batch_size, M, k, 2]
neighbor_probs = tf.gather_nd(p_s, gather_indices) # [batch_size, M, k]
# 3. 计算邻居概率和并检查约束
neighbor_sum = tf.reduce_sum(neighbor_probs, axis=-1) # [batch_size, M]
threshold = 4.0 / M_float
violations = tf.nn.relu(neighbor_sum - threshold) # [batch_size, M]
# 4. 计算惩罚项(平均违反程度×系数)
local_prob_penalty = tf.reduce_mean(violations) * 50.0
# === 概率分布可逆性约束 ===
# ... [保持不变] ...
# === 修改后的损失函数(添加局部概率约束惩罚) ===
loss = (loss_Eq * eq_flag * 0.5
- GMI
+ loss_NGMI * 100
+ distance_penalty
+ entropy_penalty
+ prob_floor_penalty
+ local_prob_penalty) # 新增局部概率约束惩罚
# 梯度计算与更新
variables = (logit_model.trainable_variables * PS_flag +
s_model.trainable_variables +
Trans_model_bit.trainable_variables * GS_flag +
Model_Eq.trainable_variables * eq_flag)
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
# 返回结果(添加局部概率约束相关指标)
return loss, loss_Eq, NGMI, GMI, tf.reduce_mean(entropy_S), p_s, norm_constellation, x, min_distance, local_prob_penalty
```
### 关键修改解释:
1. **距离矩阵修正**:
- 修正了`mask`的创建方式,使用星座点数`M_int`而不是batch_size
- 确保`valid_distances`正确排除自身距离
2. **局部概率约束实现**:
- **邻居识别**:使用`tf.math.top_k`在`valid_distances`上找到最近3个邻居
- **概率收集**:通过`tf.gather_nd`高效收集邻居概率值
- **约束计算**:计算邻居概率和并检查是否超过`4/M`阈值
- **惩罚项**:对违反约束的程度取平均并乘以惩罚系数(50.0)
3. **损失函数更新**:
- 新增`local_prob_penalty`项,防止概率集中在局部区域
4. **返回值扩展**:
- 添加`local_prob_penalty`到返回结果,便于监控约束效果
### 注意事项:
1. 该约束会增加计算复杂度(O(batch_size×M²)),但现代GPU通常能高效处理
2. 惩罚系数50.0是启发式设置,可根据训练效果调整
3. 当星座点数M较小时(如<16),建议减少邻居数k
4. 新增约束可能影响收敛速度,建议监控各项损失分量