### Build the equalizer model: a small 1-D CNN followed by CAP demodulation.
### Input is one flattened frame of the (un-downsampled) receiver waveform,
### length batchSize*Fs samples.
Eq_inp = Input(shape=(batchSize*Fs,),dtype=dtype)
# Append a channel axis so Conv1D sees shape (batch, time, channels=1).
Eq_inp_1 = tf.expand_dims(Eq_inp,axis=-1)
### Directly input the undownsampled matched filter output to the equalizer
# Two nonlinear conv stages (8 filters, 32-tap kernels, ReLU) ...
Eq_1 = Conv1D(filters=8, kernel_size=32, padding='SAME', activation='relu',name='nl_filters1')(Eq_inp_1)
Eq_1 = Conv1D(filters=8, kernel_size=32, padding='SAME', activation='relu',name='nl_filters2')(Eq_1)
# ... then a linear 1x1 projection collapses the 8 channels back to one
# equalized waveform.
Eq_out = Conv1D(filters=1, kernel_size=1, padding='SAME', activation='linear',name='Eq_out')(Eq_1)
### CAP demodulation
# Correlate the equalized waveform with the in-phase (g_cos) and
# quadrature (-g_sin) matched filters. g_cos/g_sin are fixed (non-trainable)
# kernels defined earlier in the file — presumably the CAP shaping pulses;
# TODO(review): confirm their shape is (width, 1, 1) as tf.nn.conv1d requires.
y_match_real = tf.nn.conv1d(Eq_out, g_cos, stride=1, padding='SAME')
y_match_imag = tf.nn.conv1d(Eq_out, -g_sin, stride=1, padding='SAME')
# Downsample to symbol rate (keep every Fs-th sample) and drop the
# singleton channel axis, leaving shape (batch, nSymbols).
y_match_real_ds = y_match_real[:,0::Fs,0]
y_match_imag_ds = y_match_imag[:,0::Fs,0]
# Pack the I/Q branches into one complex-valued symbol stream.
Demod_out = tf.complex(y_match_real_ds,y_match_imag_ds)
# End-to-end model: raw waveform in, complex demodulated symbols out.
Model_Eq = Model(inputs = [Eq_inp], outputs = [Demod_out], name = 'Model_Eq')
## train the equalizer
# Training data built earlier in the file; converted from Python lists to
# ndarrays for Keras. NOTE(review): assumed shapes are
# Rx_rand_set -> (nSets, batchSize*Fs) and Tx_label_set -> matching complex
# symbol labels — confirm against the generation code upstream.
Tx_label_set = np.array(Tx_label_set)
Rx_rand_set = np.array(Rx_rand_set)
# MSE loss on the model's complex output; Adam with default settings.
Model_Eq.compile(optimizer='adam',loss='mse')
Model_Eq.summary()
# Keep only the best model (lowest val_loss), checked once per epoch.
model_Eq_ckpt = ModelCheckpoint(filepath='./models/Model_Eq_bit'+str(bitlen)+'bits.hdf5',monitor='val_loss',save_best_only=True,save_weights_only=False,verbose=1,save_freq='epoch')
# First nBatches frames train; the remaining frames are held out for
# validation.
Loss = Model_Eq.fit(Rx_rand_set[0:nBatches],Tx_label_set[0:nBatches],validation_data=(Rx_rand_set[nBatches: ],Tx_label_set[nBatches: ]),
epochs=1000,verbose=1,callbacks=[model_Eq_ckpt])
# Reload the best checkpoint — after fit() the in-memory weights are from
# the last epoch, not necessarily the best one.
Model_Eq = tf.keras.models.load_model('./models/Model_Eq_bit'+str(bitlen)+'bits.hdf5')
# Equalize + demodulate the full received sequence Rx_MB (prepend a batch
# axis of size 1).
Rx_Eq = Model_Eq(Rx_MB[np.newaxis,:])
print('[DEBUG] Equalizer retrained, loss: %f,'%Loss.history['loss'][-1], 'output shape:',Rx_Eq.shape)
# Plot training/validation loss curves on a log10 scale.
fig02, ax02 = plt.subplots()
ax02.plot(np.arange(0,len(Loss.history['loss'])),np.log10(Loss.history['loss']),'r',label="Train loss")
ax02.plot(np.arange(0,len(Loss.history['val_loss'])),np.log10(Loss.history['val_loss']),'b',label="Val loss")
ax02.legend(loc='upper right')
ax02.set_xlabel('Epoch')
ax02.set_ylabel('log10(Loss)')
# TODO: change the log output so progress is printed only once every 100 iterations
# NOTE: latest release