模型数据的保存和读取

本文介绍了如何将模型数据保存到本地,以便在训练过程中分段进行或在训练完成后重复利用。通过代码展示了模型参数的保存和读取,以及如何在重启后加载模型进行测试。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1,基本内容
目的是将模型数据以文件的形式保存到本地。
使用神经网络模型进行大数据量和复杂模型训练时,训练时间可能会持续增加,此时为避免训练过程出现不可逆的影响,并验证训练效果,可以考虑分段进行,将训练数据模型保存,然后在继续训练时重新读取; 此外,模型训练完毕,获取一个性能良好的模型后,可以保存以备重复利用。
2,参数保存和读取代码

import tensorflow as tf
#随机初始化两个变量
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")#矩阵大小为[1,2]
v2 = tf.Variable(tf.random_normal([2,4]), name="v2")#矩阵大小为[2,4]
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()#定义该类的一个对象
with tf.Session() as sess:
    sess.run(init_op)
    print ("V1:",sess.run(v1))  
    print ("V2:",sess.run(v2))
    saver_path = saver.save(sess, "Save/model.ckpt")#保存sess计算域中所有的参数值
    print ("Model saved")
    saver.restore(sess, "Save/model.ckpt")#读取保存的文件
    print ("V1_1:",sess.run(v1))  
    print ("V2_1:",sess.run(v2))
    print ("Model restored")

运行结果:
在这里插入图片描述
在这里插入图片描述
2,网络模型的保存与读取代码

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg   = mnist.trai
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值