if Eq_retrain == 1 and Retrain_flag == 0:
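# Maxwell-Boltzmann prior over the constellation: P(x) ∝ exp(-v*|x|^2); v = 0 gives a uniform distribution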
v = 0
c = (-v) * (np.abs(cons_init)**2)
fenzi = np.exp(c)
fenmu = np.sum(fenzi)
P_MB = fenzi / fenmu
P_MB = P_MB / np.sum(P_MB)
entropy_S_MB = -p_norm(P_MB, P_MB, lambda x: np.log2(x))
power = np.abs(cons_init) ** 2
avg_power = np.mean(power)
norm_epsilon = 1e-12
if avg_power < norm_epsilon:
scale_factor = 1.0
else:
scale_factor = np.sqrt(power_limation / avg_power)
norm_cons_MB = cons_init * scale_factor
Tx_cons_MB = np.array(random.choices(cons_init, weights=P_MB, k=batchSize))
Tx_cons_MB = Tx_cons_MB * scale_factor
noise = np.random.normal(0, 1, size=Tx_cons_MB.shape) + 1j * np.random.normal(0, 1, size=Tx_cons_MB.shape)
noise = np.complex64(noise)
cons_MB = Tx_cons_MB + 0.02 * noise
cmap = plt.get_cmap('Purples')
color_norm = colors.SymLogNorm(linthresh=2500, linscale=3000,
vmin=P_MB.min(), vmax=P_MB.max()*900, base=20)
fig11, ax11 = plt.subplots()
heatmap, xedges, yedges = np.histogram2d(np.real(cons_MB.ravel()),
np.imag(cons_MB.ravel()),
bins=500, density=True)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
ax11.imshow(10*heatmap.T, norm=color_norm, extent=extent,
origin='lower', cmap=cmap)
ax11.axis((-1, 1, -1, 1))
ax11.set_title('MB Constellation')
fig11.show()
Tx_label_set = []
Rx_rand_set = []
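# Generate 2*nBatches batches: the first nBatches for training, the rest for validation (see the fit() call below)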
for i in range(nBatches * 2):
Tx_cons_MB = np.array(random.choices(cons_init, weights=P_MB, k=batchSize))
Tx_cons_MB = Tx_cons_MB * scale_factor
Tx_cons_MB_RRC = upsample_pulse_shaping(Tx_cons_MB, Fs, h_rrc, fa, fc, draw = False)
Tx_length = tf.shape(Tx_cons_MB_RRC)[0]
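# Zero-pad the head so the channel FIR's group delay can be trimmed off with Rx_MB[0:Tx_length] below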
num_zeros = int(len(channel_FIR)/2-number_f)
zeros_tx = tf.zeros(shape=(num_zeros,), dtype=Tx_cons_MB_RRC.dtype)
Tx_mapped = tf.concat([zeros_tx, Tx_cons_MB_RRC], axis=0)
Rx_MB = channel_ISI_NL(Tx_mapped, channel_SNR_db1, channel_FIR,nl_factor=nl_factor, ch_type=channel_type)
Rx_MB = Rx_MB[0:Tx_length]
Tx_label_set.append(Tx_cons_MB)
Rx_rand_set.append(Rx_MB)
if Eq_retrain == 1 and Retrain_flag == 0:
class EveryNEpochsLog(tf.keras.callbacks.Callback):
def __init__(self, n=100):
super(EveryNEpochsLog, self).__init__()
self.n = n
def on_epoch_end(self, epoch, logs=None):
if logs is None:
logs = {}
if (epoch + 1) % self.n == 0:
loss = logs.get('loss')
val_loss = logs.get('val_loss')
loss_str = f"{loss:.6f}" if loss is not None else "unknown"
val_loss_str = f"{val_loss:.6f}" if val_loss is not None else "unknown"
print(f"Epoch {epoch+1}/{self.params['epochs']}: loss={loss_str}, val_loss={val_loss_str}")
Eq_inp = Input(shape=(batchSize*Fs,),dtype=dtype)
Eq_inp_1 = tf.expand_dims(Eq_inp,axis=-1)
Eq_1 = Conv1D(filters=8, kernel_size=32, padding='SAME', activation='relu',name='nl_filters1')(Eq_inp_1)
Eq_1 = Conv1D(filters=8, kernel_size=32, padding='SAME', activation='relu',name='nl_filters2')(Eq_1)
Eq_out = Conv1D(filters=1, kernel_size=1, padding='SAME', activation='linear',name='Eq_out')(Eq_1)
Demod_out = downconvert_matched_filter(Eq_out, Fs, h_rrc, fa, fc)
Model_Eq = Model(inputs = [Eq_inp], outputs = [Demod_out], name = 'Model_Eq')
Tx_label_set = np.array(Tx_label_set)
Rx_rand_set = np.array(Rx_rand_set)
Model_Eq.compile(optimizer='adam',loss='mse')
Model_Eq.summary()
model_Eq_ckpt = ModelCheckpoint(
filepath=eq_model_path,
monitor='val_loss',
save_best_only=True,
save_weights_only=False,
verbose=0, # suppress ModelCheckpoint's default console output
save_freq='epoch'
)
every_100_log = EveryNEpochsLog(n=100)
Loss = Model_Eq.fit(
Rx_rand_set[0:nBatches],
Tx_label_set[0:nBatches],
validation_data=(Rx_rand_set[nBatches:], Tx_label_set[nBatches:]),
epochs=1000,
verbose=0, # disable the default progress-bar output
callbacks=[model_Eq_ckpt, every_100_log] # attach the custom logging callback
)
Model_Eq = tf.keras.models.load_model(eq_model_path)
Rx_Eq = Model_Eq(Rx_MB[np.newaxis,:])
print(f"[DEBUG] Equalizer retrained, loss: {Loss.history['loss'][-1]:.6f}, output shape: {Rx_Eq.shape}")
fig02, ax02 = plt.subplots()
ax02.plot(np.arange(0,len(Loss.history['loss'])),np.log10(Loss.history['loss']),'r',label="Train loss")
ax02.plot(np.arange(0,len(Loss.history['val_loss'])),np.log10(Loss.history['val_loss']),'b',label="Val loss")
ax02.legend(loc='upper right')
ax02.set_xlabel('Epoch')
ax02.set_ylabel('log10(Loss)')
else:
Model_Eq = tf.keras.models.load_model(eq_model_path)
print('[INFO] Load Equalizer')
# Step 4 revision: added symbol error rate (SER) calculation
v = 0
c = (-v) * (np.abs(cons_init)**2)
fenzi = np.exp(c)
fenmu = np.sum(fenzi)
P_MB = fenzi / fenmu
P_MB = P_MB / np.sum(P_MB) # renormalize (simplifies the earlier repeated assignment)
entropy_S_MB = -p_norm(P_MB, P_MB, lambda x: np.log2(x))
symbol_power = np.abs(cons_init) ** 2
weighted_avg_power = np.sum(P_MB * symbol_power)
norm_epsilon = 1e-12
if weighted_avg_power < norm_epsilon:
scale_factor = 1.0
else:
scale_factor = np.sqrt(power_limation / weighted_avg_power)
norm_factor_MB = r2c(scale_factor)
norm_cons_MB = norm_factor_MB * cons_init
norm_cons_np = norm_cons_MB.numpy()  # NumPy copy for the minimum-distance slicers below
# initialize error counters
total_symbols = 0
symbol_errors_equalized = 0
symbol_errors_matched = 0
GMI_MonteCarlo = []
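# Monte-Carlo loop: transmit random MB-shaped batches, equalize, slice, and accumulate SER and GMI statistics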
for i in range(nBatches):
Tx_cons_MB = np.array(random.choices(cons_init, weights=P_MB, k=batchSize))
Tx_cons_MB = norm_factor_MB * Tx_cons_MB  # norm_factor_MB is already complex; no second r2c needed
Tx_cons_MB_RRC = upsample_pulse_shaping(Tx_cons_MB, Fs, h_rrc, fa, fc)
Tx_length = tf.shape(Tx_cons_MB_RRC)[0]
num_zeros = int(len(channel_FIR)/2-number_f)
zeros_tx = tf.zeros(shape=(num_zeros,), dtype=Tx_cons_MB_RRC.dtype)
Tx_mapped = tf.concat([zeros_tx, Tx_cons_MB_RRC], axis=0)
Rx_MB = channel_ISI_NL(Tx_mapped, channel_SNR_db1, channel_FIR,nl_factor=nl_factor, ch_type=channel_type)
Rx_MB = Rx_MB[0:Tx_length]
Rx_ME_eq = Model_Eq(Rx_MB[np.newaxis, :])
Tx_label = Tx_cons_MB.numpy().ravel().astype(np.complex64)
Rx_label = Rx_ME_eq.numpy().ravel().astype(np.complex64)
decisions_equalized = norm_cons_np[np.argmin(np.abs(Rx_label[:, None] - norm_cons_np[None, :]), axis=1)]  # minimum-distance slicer, vectorized
symbol_errors_equalized += np.sum(~np.isclose(decisions_equalized, Tx_label, atol=1e-3))
total_symbols += len(Tx_label)
GMI_MonteCarlo.append(GMIcal(Tx_label, Rx_label, M, norm_cons_np, hard_bits_out, P_MB))
last_Tx_cons_MB = np.array(random.choices(cons_init, weights=P_MB, k=batchSize))
last_Tx_cons_MB = norm_factor_MB * last_Tx_cons_MB  # norm_factor_MB is already complex; no second r2c needed
last_Tx_cons_MB_RRC = upsample_pulse_shaping(last_Tx_cons_MB, Fs, h_rrc, fa, fc)
Tx_length = tf.shape(last_Tx_cons_MB_RRC)[0]
num_zeros = int(len(channel_FIR)/2-number_f)
zeros_tx = tf.zeros(shape=(num_zeros,), dtype=last_Tx_cons_MB_RRC.dtype)
Tx_mapped = tf.concat([zeros_tx, last_Tx_cons_MB_RRC], axis=0)
last_Rx_MB = channel_ISI_NL(Tx_mapped, channel_SNR_db1, channel_FIR,nl_factor=nl_factor, ch_type=channel_type)
last_Rx_MB = last_Rx_MB[0:Tx_length]
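# Matched-filter baseline (no neural equalizer): downconvert, estimate and remove phase, normalize, then slice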
y_match = downconvert_matched_filter(last_Rx_MB, Fs, h_rrc, fa, fc)
y_match = phase_noise_estimation(last_Tx_cons_MB, y_match)
y_match_np = y_match.numpy().astype(np.complex64) if isinstance(y_match, tf.Tensor) else y_match.astype(np.complex64)
y_match_flat = y_match_np.reshape(-1)
tx_power = power_limation
y_power = np.mean(np.abs(y_match_flat)** 2)
norm_scale = np.sqrt(tx_power / (y_power + norm_epsilon))
y_match_flat_normalized = (y_match_flat * norm_scale).astype(np.complex64)
last_Tx_label = last_Tx_cons_MB.numpy().ravel().astype(np.complex64)
match_length = min(len(y_match_flat_normalized), len(last_Tx_label))
y_match_flat_normalized = y_match_flat_normalized[:match_length]
last_Tx_label = last_Tx_label[:match_length]
decisions_matched = norm_cons_np[np.argmin(np.abs(y_match_flat_normalized[:, None] - norm_cons_np[None, :]), axis=1)]  # minimum-distance slicer, vectorized
symbol_errors_matched = np.sum(~np.isclose(decisions_matched, last_Tx_label, atol=1e-3))
SER_matched = symbol_errors_matched / len(last_Tx_label) if len(last_Tx_label) > 0 else 0
SER_equalized = symbol_errors_equalized / total_symbols if total_symbols > 0 else 0
fig0, ax0 = plt.subplots(1, 2)
fig0.set_size_inches(10, 5)
ax0[0].set_title('Matched Filter (Normalized)')
ax0[0].scatter(np.real(y_match_flat_normalized), np.imag(y_match_flat_normalized),
s=10, label='Received')
ax0[0].scatter(np.real(norm_cons_MB), np.imag(norm_cons_MB),
s=50, c='red', marker='x', label='Tx Constellation')
ax0[0].legend(loc='upper left')
last_Rx_ME_eq = Model_Eq(last_Rx_MB[np.newaxis, :])
last_Rx_label = last_Rx_ME_eq.numpy().ravel().astype(np.complex64)[:match_length]
ax0[1].set_title('Equalized')
ax0[1].scatter(np.real(last_Rx_label), np.imag(last_Rx_label), s=10, label='Received')
ax0[1].scatter(np.real(norm_cons_MB), np.imag(norm_cons_MB),
s=50, c='red', marker='x', label='Tx Constellation')
ax0[1].legend(loc='upper left')
fig0.show()
GMI_MB = np.mean(GMI_MonteCarlo)
NGMI_MB = 1 - (entropy_S_MB - GMI_MB) / bitlen
print(f'[DEBUG] SER (Equalized) = {SER_equalized:.6f}')
print(f'[DEBUG] SER (Matched Filter) = {SER_matched:.6f}')
print(f'[DEBUG] MB_GMI = {GMI_MB:.2f}, MB_NGMI = {NGMI_MB:.2f}, entropy = {entropy_S_MB:.2f}, {channel_SNR_db} dB SNR')
This is training and testing code for an equalizer; I need the following modifications. First, I have already obtained the channel impulse response (CIR):
channel_FIR = np.loadtxt(channel_FIR_path)
filter_len = channel_FIR.shape[0]
Replace the convolutional layers in the equalizer above (that is, everything except the downconvert_matched_filter part) with an MMSE-based equalizer whose initial tap coefficients are derived from the CIR, and which can then adapt automatically from the error during subsequent training and real use; the former downconvert_matched_filter is implemented outside the model (a sketch follows the phase_noise_estimation helper below).
Second, this model must take part in the training of a larger model under TensorFlow graph mode, so mind how the model is implemented; the model must also be callable from other functions.
Third, change the power normalization into amplitude normalization: scale everything uniformly so that the maximum amplitude is limited to 3 (see the sketch right after this list).
Fourth, the equalizer's input and output should both be real-valued. The equalizer output should pass through downconvert_matched_filter and phase recovery before being compared with the transmitted symbol data.
Fifth, keep the current training and testing procedure; do not wrap it into a function.
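For the third item, a minimal sketch of the amplitude normalization, replacing the power-based scale_factor blocks above; the peak target of 3 comes from the request, while the epsilon guard is an assumption:
max_amp = 3.0  # requested peak amplitude
peak = np.max(np.abs(cons_init))
scale_factor = max_amp / (peak + 1e-12)  # epsilon (assumed) guards an all-zero constellation
norm_cons_MB = cons_init * scale_factor  # uniform scaling preserves the constellation geometry
For reference, the existing phase recovery helper mentioned in the fourth item: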
def phase_noise_estimation(x, y):
phase_noise = tf.reduce_mean(tf.math.angle(tf.math.conj(x)*y))
y_phase_corrected = y*tf.math.exp(-1j*r2c(phase_noise))
return y_phase_corrected
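For the first, second, and fourth items, a minimal sketch of an MMSE-initialized, trainable FIR equalizer layer. It assumes channel_FIR, filter_len, channel_SNR_db1, batchSize, Fs, Input, and Model from the script above; eq_len, the center-delay target, and the unit-signal-power noise variance are assumptions, not taken from the original code. Because it is a plain tf.keras.layers.Layer holding a tf.Variable, it can be embedded in a larger Keras model, traced under graph mode (tf.function), and called from other functions; its input and output are real tensors.

import numpy as np
import tensorflow as tf
from scipy.linalg import toeplitz

def mmse_taps_from_cir(h, eq_len, snr_db):
    """Initial MMSE taps w = (H^T H + sigma2*I)^-1 H^T e_d for a real FIR channel h."""
    sigma2 = 10.0 ** (-snr_db / 10.0)  # noise variance, assuming unit signal power
    # Convolution matrix H, sized (len(h)+eq_len-1, eq_len), so that H @ w == np.convolve(h, w)
    H = toeplitz(np.concatenate([h, np.zeros(eq_len - 1)]),
                 np.concatenate([[h[0]], np.zeros(eq_len - 1)]))
    e_d = np.zeros(H.shape[0])
    e_d[(len(h) + eq_len) // 2] = 1.0  # target a delta at the (assumed) center delay
    w = np.linalg.solve(H.T @ H + sigma2 * np.eye(eq_len), H.T @ e_d)
    return w.astype(np.float32)

class MMSEEqualizer(tf.keras.layers.Layer):
    """Real-in/real-out FIR equalizer; taps start at the MMSE solution and stay trainable."""
    def __init__(self, init_taps, **kwargs):
        super().__init__(**kwargs)
        # tf.nn.conv1d cross-correlates, so flip the taps to realize a true convolution
        taps0 = np.asarray(init_taps, dtype=np.float32)[::-1].reshape(-1, 1, 1)
        self.taps = tf.Variable(taps0, trainable=True, name='taps')

    def call(self, x):
        # x: (batch, length) real passband samples -> (batch, length) equalized samples
        y = tf.nn.conv1d(tf.expand_dims(x, -1), self.taps, stride=1, padding='SAME')
        return tf.squeeze(y, -1)

# Hypothetical wiring into the existing script (eq_len = 2*filter_len is an arbitrary choice):
init_taps = mmse_taps_from_cir(channel_FIR, eq_len=2 * filter_len, snr_db=channel_SNR_db1)
Eq_inp = Input(shape=(batchSize * Fs,), dtype=tf.float32)
Eq_out = MMSEEqualizer(init_taps, name='mmse_eq')(Eq_inp)
Model_Eq = Model(inputs=[Eq_inp], outputs=[Eq_out], name='Model_Eq_MMSE')
# downconvert_matched_filter(...) and phase_noise_estimation(...) then run outside
# the model before the output is compared with the transmitted symbols.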