tensorflow2.0 模型的保存
- 通过ModelCheckpoint保存模型,该方式只保存网络达到某一种要求状态下的参数(比如训练模型达到最好时的参数),而不管其他的状态。
- 通过saved_model保存,是模型的一种保存格式,可以把这个保存的模型直接交给用户来部署,而不需要提供源代码或相关的信息,保存的模型就包含了这些信息。
- 通过model.save保存模型,是把模型所有的状态都保存起来,简单粗暴。
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
2.0.0
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation='relu'),
keras.layers.Dense(100, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer = "sgd",
metrics = ["accuracy"])
1. ModelCheckpoint保存模型
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_weights.h5")
callbacks = [
keras.callbacks.TensorBoard(logdir),
keras.callbacks.ModelCheckpoint(output_model_file, #保存模型
save_best_only = True,
save_weights_only = False),
keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
history = model.fit(x_train_scaled, y_train, epochs=10,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
model.evaluate(x_test_scaled, y_test)
loaded_model = keras.models.load_model(output_model_file) #载入模型
loaded_model.evaluate(x_test_scaled, y_test)
2. saved_model保存模型
tf.saved_model.save(model, "./keras_saved_graph")#保存模型
!saved_model_cli show --dir ./keras_saved_graph --all #查看模型,签名
loaded_saved_model = tf.saved_model.load('./keras_saved_graph') #载入模型
print(list(loaded_saved_model.signatures.keys())) #打印签名
inference = loaded_saved_model.signatures['serving_default'] #获取函数句柄
print(inference)
print(inference.structured_outputs)
results = inference(tf.constant(x_test_scaled[0:1])) #测试
print(results['dense_2'])
3. model.save保存模型
model.save('save_model.h5')
print('saved total model.')
print('loaded model from file.')
loaded_model = tf.keras.models.load_model('save_model.h5', compile=False) # 不需要重新创建网络
loaded_model.evaluate(x_test_scaled, y_test)
参数的保存和载入
model.save_weights(os.path.join(logdir, "fashion_mnist_weights_2.h5")) #保存参数
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_weights_2.h5")
'''
# 载入参数前要先建立模型结构
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation='relu'),
keras.layers.Dense(100, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer = "sgd",
metrics = ["accuracy"])
'''
model.load_weights(output_model_file) #载入参数
model.evaluate(x_test_scaled, y_test)
本文详细介绍了在TensorFlow 2.0中如何使用ModelCheckpoint、saved_model和model.save三种方法保存模型,包括参数保存、完整模型保存及模型的加载过程,适用于不同场景下的模型部署和应用。
884

被折叠的 条评论
为什么被折叠?



