TensorFlow之模型保存与加载

本文介绍了如何在Keras中使用检查点保存训练过程,以便于恢复中断的训练或共享模型。通过创建检查点,可以保存模型的权重,当训练中断时从中恢复。模型的完整保存包括架构、权重和训练设置,可以是SavedModel或HDF5格式,允许直接加载和继续训练,或进行测试和预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

模型在训练过程中或者在训练之后,模型的执行过程能被保存,也就是,模型能从暂停中恢复以免训练的时间过长。因此,被保存的模型可以被共享,其他人可以重新构建相同的模型。被保存的模型以如下的两种方式进行共享:

  •  创建模型的代码

  • 被训练模型的权重或者参数

安装类库

如上所示,安装保存模型的类库、工具集,并导入其支持的函数,其中,以HDF5的文件格式保存模型。

加载数据样本

如上所示,加载mnist训练数据集,该数据集是流行服饰的图片数据集,其中,train_images是训练的图片数据集,train_labels是训练的标签数据集,分别使用mnist数据集的前1000条记录。

定义模型

如上所示,定义创建序列模型的函数、创建模型,其中,该模型仅定义一个全连接层、每层512个单元、使用删除正则化机制、使用Adam优化器优化模型。

保存检查点

检查点是用于保证训练过程的可用性,当训练过程被中断,设置检查点可以保存中断之前的训练数据,例如,保存了训练过程的权重或者参数,用户可以从检查点中恢复训练。Keras技术框架提供tf.keras.callbacks.ModelCheckpoint工具类用于保存训练过程的检查点,用户使用检查点可以持续地保存训练过程或者训练结束的训练数据。

检查点回调(callback)

如上所示,以回调的方式创建了一个模型的检查点,其中,checkpoint_path是保存检查点对应的文件路径,cp_callback是创建一个检查点函数,callbacks=[cp_callback]是为模型设置了训练过程中回调检查点函数、持续地保存训练过程中的检查点,该函数是在每次训练迭代结束的时候被回调,训练多次迭代则多次回调。

如上所示,显示训练结束之后,检查点保存路径下的文件信息。

如上所示,先创建一个模型model,在训练之前对模型执行测试评估,显示其准确度只有7%,使用模型的load_weights函数加载之前训练的检查点,创建一个权重都相同的模型,然后执行测试评估,显示其准确度达到86.60%(重用已训练完成的模型,包括权重)。

检查点选项设置

如上所示,用户可以根据实际情况设置检查点的参数选项,例如,checkpoint_path可以设置每次训练迭代保存的检查点的名称,save_freq设置检查点保存的频率,每多少次训练迭代保存一次检查点,日志显示每5次训练迭代记录一次检查点。

如上所示,显示模型检查点保存的文件列表,以及显示最后一个检查点。

如上所示,使用load_weights函数加载模型的最后一次训练的检查点,其包括该次检查点的权重,然后,执行模型的测试评估,其准确度达到87.30%。

检查点文件属性

检查点文件是二进制格式的字节文件,其中保存了训练过程中学习到的权重信息,其描述如下所示:

  • 一个或者多个分片,每个分片都包含了模型训练所得的权重信息

  • 一个索引文件,对以上的分片文件进行索引

在单节点的机器中训练模型,则生成的其中一个分片的检查点文件的后缀是.data-00000-of-00001。

手动保存权重

如上所示,使用模型的save_weights保存模型的权重,load_weights函数是加载已保存的权重,然后,使用测试数据集对模型执行测试评估,其准确度达到87.30%。

保存完整模型

使用keras技术框架的tf.keras.Model.save函数可以保存完整的模型,该保存方式包括模型的架构、模型权重、模型训练的设置,该保存方式可以让用户直接重用已保存的模型,而不用修改模型的代码,也就是,由于优化器的状态能被恢复,用户可以直接恢复模型到上次训练暂停的时候。该保存方式的包括两种格式,SavedModel以及HDF5,其中,SavedModel是默认的保存格式。

SavedModel格式

如上所示,使用模型的save函数保存完整的已训练完成的模型。

如上所示,显示保存模型的文件列表,然后,使用load_model函数加载已保存的模型,输出模型的汇总信息。

如上所示,重新加载了已训练完成、完整的、相同的模型,然后,使用模型进行测试评估、预测。

HDF5格式

如上所示,创建一个模型、对模型执行训练、以HDF5的格式保存完整的模型。

如上所示,重新加载HDF5格式的已经保存的模型,输出模型汇总信息。

如上所示,使用测试数据集对重新加载的模型执行测试评估,其准确度达到85.90%。

由以上的分析可知,Keras技术框架保存模型的信息如下所示:

  •  模型训练所得的权重值

  • 模型的架构

  • 模型的训练设置,例如,传入到模型compile函数的参数

  • 模型的优化器以及优化器的状态

(未完待续)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

计算机科技研究员

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值