踩坑
不同h5文件来保存,不管是用save和load,虽然写法比较简单,但是经常遇到各种各样的毛病。还有save_weights和load_weights,尽管能用,但是仍然存在问题
最后发现一种最好用的方法
保存
Model1 = Net1()
scce = tf.keras.losses.SparseCategoricalCrossentropy()
adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=10e-8, amsgrad=False, name="Aadm")
Model1.compile(optimizer=adam,loss=scce, metrics=['accuracy'])
checkpoint_save_path = './/model1//Model1.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
Model1.fit(k_x_train,k_y_train,batch_size=64,epochs=30,callbacks=cp_callback)
predictions_nn1 = Model1.predict(X_test)
加载
model1 = net.Net1()
scce = tf.keras.losses.SparseCategoricalCrossentropy()
adam = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=10e-8, amsgrad=False, name="Aadm")
model1.compile(optimizer=adam,loss=scce, metrics=['accuracy'])
checkpoint_save_path = './/Model123.ckpt'
model1.load_weights(checkpoint_save_path)
predictions_nn = model1.predict(X_test)
本文通过实践经验分享了使用H5文件进行模型保存及加载过程中遇到的问题,并提供了一种有效的解决方案。通过具体实例展示了如何利用`tf.keras.callbacks.ModelCheckpoint`进行模型权重的保存与加载,确保了模型训练结果能够被稳定复现。
8176

被折叠的 条评论
为什么被折叠?



