模型的读取
我们经常需要吧写好的模型进行保存于读取。
在模型的读取中,tensorflow给出了利用load_weights(路径文件名)的方式进行读取.因此我们在模型的读取中经常使用一下的方式:
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model--------------')
model.load_weights(checkpoint_save_path)
由于在生成ckpt文件的时候,会同步的生存索引表,因此我们可以判断是否已经有了索引表,就可以判断是否保存了模型的参数。
模型的保存
保存模型的参数可以使用tensorflow给的回调函数,直接保存训练出来的模型参数。需要告知模型的保存路径,是否只保存模型参数,是否只保存模型的最优结果。
tf.keras.callbacks.ModelCheckpoint(
filepath=路径文件名,
save_weights_only=True/False,
save_best_only=True/False
)
history = model.fit(callbacks=[cp_callback])
在实际应用中保存模型我们经常这样做:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_save_path,
save_weights_only=Yrue,\
save_best_only=True
)
history = model.fit(x_train,y_train,batch_size=32, epochs=5,
validation_data=(x_test,y_test), validation_freq=1,
callbacks=[cp_callback]
)
保存模型的好处,就是通过保存最好的模型在最好的模型基础上进行训练的时候,模型的识别结果时在最好的模型基础上继续提升。
案例程序
import tensorflow as tf
import os
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
print('------------------------load the model---------------------')
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=10, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1,
callbacks=[cp_callback])
model.summary()
训练结果如下所示:
然后我们再训练一次就会加载保存好的最好模型开始训练。训练结果如下所示: