<think>我们修改了train_step函数,现在它返回了额外的值:min_distance(星座点间的最小距离)
因此,在调用train_step的地方,我们需要增加一个变量来接收这个返回值。
同时,在训练循环中,我们可能需要记录或显示这个最小距离。
具体修改点:
1. 在调用train_step时,现在有9个返回值(原来8个,现在增加了一个min_distance)
2. 在batch循环中,我们需要接收这个返回值,并可能将其用于后续处理(比如打印)
3. 在绘图部分,我们可能希望展示这个最小距离的变化(可选,根据需求)
修改步骤:
1. 修改train_step调用,增加一个变量接收min_distance
2. 在需要记录或显示的地方,可以累计每个batch的最小距离,然后计算平均最小距离
3. 在打印信息时,可以加入这个最小距离的信息
注意:由于我们返回了min_distance,所以需要修改train_step的返回值为9个(原8个+1个)
在train_step中,我们返回的是:loss, loss_Eq, NGMI, GMI, entropy_S, p_s, norm_constellation, x, min_distance
因此,在调用train_step的地方,需要改为:
batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S, p_s, norm_constellation, x, min_distance = train_step(...)
然后,在训练循环中,我们可以累计每个batch的min_distance,然后计算平均最小距离。
另外,在打印信息时,我们可以显示这个平均最小距离。
修改后的代码:</think>根据修改后的`train_step`函数,我们需要对`model_train`函数进行相应调整。主要修改包括处理新增的返回值(最小距离)以及优化训练过程。以下是修改后的代码:
```python
def model_train(batchsize, channel_SNR, noise_init, nl_factor, eq_flag, epsilon, earlystop_epoch, min_distance_threshold=0.5):
## start training
# constellation animation
fig1, ax1 = plt.subplots(figsize=(4,5))
Epoch_list = []
Loss_list = []
Min_Distance_list = [] # 新增:记录最小距离
min_loss = 1e10
max_GMI = 0
patience = 0
# initialize the noise
noise_exp = noise_init[np.newaxis,:]
GS_flag_len = len(GS_flag)
iter_per_GS_flag = iterations // GS_flag_len
GS_flag_index = 0
GS_flag_now = GS_flag[GS_flag_index]
PS_flag_now = PS_flag[GS_flag_index]
# create two saving figures
fig = plt.figure(figsize=(15, 6))
ax_1 = fig.add_subplot(121)
ax_2 = fig.add_subplot(122, projection='3d')
# 新增:创建监控最小距离的图表
fig_distance = plt.figure(figsize=(10, 5))
ax_distance = fig_distance.add_subplot(111)
for epoch in range(iterations):
total_loss = 0
total_loss_Eq = 0
GMI_mean = 0
NGMI_mean = 0
min_distance_mean = 0 # 新增:平均最小距离
if GS_flag_len > 1 and (epoch + 1) % iter_per_GS_flag == 0 and (GS_flag_index + 1) < GS_flag_len:
GS_flag_index += 1
GS_flag_now = GS_flag[GS_flag_index]
PS_flag_now = PS_flag[GS_flag_index]
print(f"[INFO] Switching to GS_flag={GS_flag_now}, PS_flag={PS_flag_now} at epoch {epoch+1}")
for batch in range(0, batchsize):
## get the channel impairments from the Rx data
noise_tf = tf.cast(noise_exp, dtype=dtype)
## apply the channel impairments and train the model
# 修改:接收train_step返回的9个值(新增min_distance)
(batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S,
p_s, norm_constellation, x, min_distance) = train_step(
channel_SNR, noise_tf, GS_flag_now, PS_flag_now, eq_flag, epsilon, min_distance_threshold
)
## save the training data x for transmission
Tx = get_Tx(channel_SNR, epsilon)[-1]
## Will be substituted by the real channel impairments in the Experiment
Rx = channel_ISI_NL(Tx, channel_SNR_db1, channel_FIR, nl_factor=nl_factor, ch_type=channel_type)
## after the transmission, get the Rx data and new noise
noise_exp = Rx - Tx
noise_exp = noise_exp[np.newaxis, :]
## calculate the loss
total_loss_Eq += batch_loss_Eq
total_loss += batch_loss
GMI_mean += GMI
NGMI_mean += NGMI
min_distance_mean += min_distance # 新增:累计最小距离
# 计算批平均值
total_loss = total_loss / batchsize
total_loss_Eq = total_loss_Eq / batchsize
GMI_mean = GMI_mean / batchsize
NGMI_mean = NGMI_mean / batchsize
min_distance_mean = min_distance_mean / batchsize # 新增:平均最小距离
# 记录训练数据
Epoch_list.append(epoch + 1)
Loss_list.append(total_loss.numpy())
Min_Distance_list.append(min_distance_mean.numpy()) # 新增:记录最小距离
# 模型保存逻辑
if total_loss < min_loss:
min_loss = total_loss
manager_Loss.save(checkpoint_number=epoch + 1)
print('[INFO] New Loss = %.6f, Save new model named ./model_free_training_checkpoints/V6_ckpt_Loss-%d.index' %
(min_loss, epoch + 1))
if GMI_mean > max_GMI:
patience = 0
max_GMI = GMI_mean
manager_MI.save(checkpoint_number=epoch + 1)
print('[INFO] New GMI = %.6f, NGMI = %.6f, Save new model named ./model_free_training_checkpoints/V6_ckpt_MI-%d.index' %
(max_GMI, NGMI_mean, epoch + 1))
else:
patience += 1
if patience > earlystop_epoch:
manager_Loss.save(checkpoint_number=epoch + 1)
print('[INFO] Early stop at epoch %d, loss = %.6f, GMI = %.6f' %
(epoch + 1, total_loss, GMI_mean))
break
# 显示训练进度和结果
if (epoch + 1) % displayStep == 0 or epoch == 0:
# 绘制星座图
X = np.real(norm_constellation)
Y = np.imag(norm_constellation)
ax1.clear()
ax1.scatter(X, Y, c=cons_colors)
ax1.axis('square')
ax1.set_title(f'Constellation (Epoch {epoch+1}, Min Dist: {min_distance_mean:.4f})')
fig1.canvas.draw()
fig1.canvas.flush_events()
# 打印训练指标
print('[DEBUG] epoch:{:3d}/{:2d} | loss: {:.2f} | loss_Eq: {:.4f} | GMI: {:.4f} | NGMI: {:.4f} | entropy_S: {:.4f} | min_dist: {:.4f}'
.format(epoch+1, iterations, total_loss.numpy(), total_loss_Eq.numpy(),
GMI_mean.numpy(), NGMI_mean.numpy(), entropy_S, min_distance_mean.numpy()))
# 绘制3D概率分布图
X_flat = X.ravel()
Y_flat = Y.ravel()
P = p_s.numpy().ravel()
# 颜色映射
width = height = 2 / np.sqrt(M + 5)
dz = P
offset = dz + np.abs(dz.min())
fracs = offset.astype(float) / offset.max()
norm = colors.Normalize(fracs.min(), fracs.max())
color_values = plt.cm.viridis(norm(fracs.tolist()))
# 更新图表
ax_1.clear()
ax_1.set_title('Constellation Tx')
ax_1.scatter(X_flat, Y_flat, s=80)
ax_1.axis('square')
ax_2.clear()
ax_2.bar3d(X_flat, Y_flat, np.zeros_like(X_flat), width, height, P,
shade=True, color=color_values, edgecolor='black')
ax_2.set_xlabel('X')
ax_2.set_ylabel('Y')
ax_2.set_zlabel('Probability')
ax_2.set_title(f'Epoch {epoch+1}: GMI = {GMI_mean.numpy():.4f}, NGMI = {NGMI_mean.numpy():.4f}')
fig.savefig(f'./figures/32QAM/PS_distribution_epoch{epoch+1}.png', dpi=300)
# 新增:绘制最小距离变化图
ax_distance.clear()
ax_distance.plot(Epoch_list, Min_Distance_list, 'b-', label='Min Distance')
ax_distance.axhline(y=min_distance_threshold, color='r', linestyle='--', label='Threshold')
ax_distance.set_xlabel('Epoch')
ax_distance.set_ylabel('Minimum Distance')
ax_distance.set_title('Minimum Constellation Distance Over Epochs')
ax_distance.legend()
ax_distance.grid(True)
fig_distance.savefig('./figures/32QAM/min_distance_trend.png', dpi=200)
# 保存检查点
manager_Loss.save(checkpoint_number=epoch + 1)
# 返回训练结果
Epoch_list = np.array(Epoch_list)
Loss_list = np.array(Loss_list)
Min_Distance_list = np.array(Min_Distance_list) # 新增:返回最小距离列表
return Epoch_list, Loss_list, Min_Distance_list
```
### 关键修改解释:
1. **函数参数扩展**:
```python
def model_train(..., min_distance_threshold=0.5):
```
- 添加了最小距离阈值参数,默认值0.5
2. **新增数据记录**:
```python
Min_Distance_list = [] # 记录每个epoch的最小距离
min_distance_mean = 0 # 当前epoch的平均最小距离
```
3. **train_step调用调整**:
```python
(batch_loss, batch_loss_Eq, NGMI, GMI, entropy_S,
p_s, norm_constellation, x, min_distance) = train_step(...)
```
- 接收train_step返回的9个值(新增min_distance)
4. **最小距离处理**:
```python
min_distance_mean += min_distance # 累计每个batch的最小距离
min_distance_mean = min_distance_mean / batchsize # 计算平均值
Min_Distance_list.append(min_distance_mean.numpy()) # 记录到列表
```
5. **可视化增强**:
- 星座图标题显示当前最小距离
- 新增最小距离变化趋势图
- 3D概率分布图标题显示GMI和NGMI
6. **训练信息输出**:
```python
print('[DEBUG] ... | min_dist: {:.4f}'.format(...))
```
- 在调试信息中新增最小距离显示
7. **返回数据扩展**:
```python
return Epoch_list, Loss_list, Min_Distance_list
```
- 返回新增的最小距离列表
### 新增功能说明:
1. **最小距离监控**:
- 实时记录并显示星座点间的最小欧氏距离
- 在图表中显示距离阈值线(红色虚线)
- 保存最小距离变化趋势图
2. **训练可视化增强**:
- 星座图标题动态更新
- 3D概率分布图标题显示关键指标
- 所有图表自动保存为图像文件
3. **训练过程优化**:
- 添加GS/PS标志切换提示
- 更详细的训练指标输出
- 图表自动清除避免内存泄漏
### 使用建议:
1. **阈值调整**:
```python
# 根据调制阶数动态调整阈值
if M == 16: # 16QAM
min_distance_threshold = 0.4
elif M == 64: # 64QAM
min_distance_threshold = 0.2
else: # QPSK等
min_distance_threshold = 0.6
```
2. **早停策略增强**:
```python
# 同时考虑GMI和最小距离的早停策略
if GMI_mean > max_GMI and min_distance_mean > min_distance_threshold * 0.8:
# 保存模型...
```
3. **动态调整惩罚**:
```python
# 在train_step中实现动态惩罚系数
penalty_factor = 50.0 * (1 + epoch/iterations) # 随训练进行增加惩罚
distance_penalty = tf.nn.relu(...) * penalty_factor
```