tensorflow中模型的存储与加载

本文介绍了使用TensorFlow进行模型存储和加载的两种主要方法:生成检查点文件(.ckpt)用于保存权重和变量,以及生成图协议文件(.pb)用于保存图形结构。并通过一个MNIST数据集的例子,展示了如何在实际代码中应用这两种方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、两种存储和加载模型的方法

tensorflow的API提供了两种方式来存储和加载模型
1、生成检查点文件(checkpoint file),扩展名是.ckpt,通过在tf.train.Saver对象上调用Saver.save()生成,它包括权重和其它在程序中定义的变量,不包括图结构。如果需要在另一个程序中使用,需要重新构建图结构,并告诉tensorflow如何处理这些权重。
2、生成图协议文件(graph proto file),这是一个二进制文件,扩展名一般是.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

二、代码

import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt
 ​
### step1 开启eager
# tf.enable_eager_execution()
​
### step2 MNIST数据下载的地址
# 数据的下载地址 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
​
### step3 从MNIST数据文件导入MNIST数据
with np.load("./estimator/0808_estimator/MNIST_data/mnist.npz") as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_train = x_train.reshape([-1, 28,28,1])
    y_train = y_train.reshape([-1])
    x_test, y_test = f['x_test'], f['y_test']
    x_test = x_test.reshape([-1, 28,28,1])
    y_test = y_test.reshape([-1])
​
### step4 生成数据集
# train_dataset = np.concatenate([x_train, y_train], axis=-1)
# test_dataset  = np.concatenate([y_test,  y_test],  axis=-1)
​
​
X = tf.placeholder(tf.float32, [None, 28,28,1])
y = tf.placeholder(tf.int64,   [None,])
​
# step5 Build the model
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,[3,3], activation='relu'),
    tf.keras.layers.Conv2D(16,[3,3], activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10)
])
​
### step6 训练
logits = mnist_model(X, training=True)
y_pred = tf.argmax(logits, axis=-1)
​
def accuracy(y_pred, y):
    return np.mean(np.equal(y_pred, y))
​
cost = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
train_op = tf.train.AdamOptimizer(0.001).minimize(cost)
### 定义global_step,trainable是False
global_step = tf.Variable(0, name="global_step", trainable=False)
​
### tf.train.Saver()要放在所有变量定义完之后
### 参数max_to_keep:保存多少个最新的checkpoint文件,默认为5,即保存最近五个checkpoint文件;
saver = tf.train.Saver()
loss_all = []
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    start = global_step.eval()
    print("start from: ", start)
    Length = y_train.shape[0]
    batch_times = int(Length/128)+1
    i = 0
    for _ in range(2):
        np.random.seed(520)
        # np.random.shuffle(train_dataset)
        for batch in range(batch_times):
            start = batch*128
            end = min(Length, start+128)
            image = x_train[start:end,...]
            label = y_train[start:end]
            ### sess.run()中运行的变脸要放在一起,用[]框起来,feed_dict中的值要是numpy类型,不能是tf.Tensor类型
            _, loss, y_ = sess.run([train_op, cost, y_pred], feed_dict={X:image, y:label})
            ### 要实现global_step的增加,必须要加上eval(),否则增加就不会成功
            global_step.assign(i).eval()
            if global_step.eval()%50==0:
                print(global_step.eval(), "  : ", loss)
                print(accuracy(y_,label))
                saver.save(sess, save_path="./model/mode.ckpt", global_step=global_step)
            if global_step.eval()%500==0:
                y_ = sess.run(y_pred, feed_dict={X: x_train, y: y_train})
                print(accuracy(y_,y_train))
                y_ = sess.run(y_pred, feed_dict={X:x_test, y:y_test})
                print(accuracy(y_,y_test))
            i += 1
            loss_all.append(loss)
            
### step7 加载模型
saver = tf.train.Saver()
with tf.Session() as sess:
	 ### 必须是先初始化再加载模型
    sess.run(tf.initialize_all_variables())
    ckpt = tf.train.get_checkpoint_state("./model")
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        ### 其实加载模型就这一句话
        saver.restore(sess,ckpt.model_checkpoint_path)
    loss, y_ = sess.run([cost, y_pred], feed_dict={X:x_test, y:y_test})
    print(accuracy(y_,y_test))
    print(loss)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值