断点续续-神经网络模型的存取

模型的读取

我们经常需要吧写好的模型进行保存于读取。

在模型的读取中,tensorflow给出了利用load_weights(路径文件名)的方式进行读取.因此我们在模型的读取中经常使用一下的方式:

checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'): 
    print('-------------load the model--------------')
    model.load_weights(checkpoint_save_path)   
    

由于在生成ckpt文件的时候,会同步的生存索引表,因此我们可以判断是否已经有了索引表,就可以判断是否保存了模型的参数。

模型的保存

保存模型的参数可以使用tensorflow给的回调函数,直接保存训练出来的模型参数。需要告知模型的保存路径,是否只保存模型参数,是否只保存模型的最优结果。

tf.keras.callbacks.ModelCheckpoint(
    filepath=路径文件名,
    save_weights_only=True/False,
    save_best_only=True/False
)

history = model.fit(callbacks=[cp_callback])

在实际应用中保存模型我们经常这样做:

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=Yrue,\
    save_best_only=True
)
history = model.fit(x_train,y_train,batch_size=32, epochs=5,
    validation_data=(x_test,y_test), validation_freq=1,
    callbacks=[cp_callback]    
)

保存模型的好处,就是通过保存最好的模型在最好的模型基础上进行训练的时候,模型的识别结果时在最好的模型基础上继续提升。

案例程序

import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------------load the model---------------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=10, epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

训练结果如下所示: 

 

 然后我们再训练一次就会加载保存好的最好模型开始训练。训练结果如下所示:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI炮灰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值