使用Tensorflow框架完美保存并实现调用训练好的模型(是模型不是参数哦、全网首篇)

本文详细介绍了如何使用Tensorflow的tf.saved_model.builder保存和调用训练好的CNN模型,包括模型的保存、加载网络和关键节点的加载。通过实例展示了如何找到并加载输入和输出节点,确保模型的正确使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

众所周知,Tensorflow框架在机器学习、深度学习领域有着广泛的应用。博主最近就Tensorflow框架进行了简单的学习,当我完成简单的回归模型以及CNN的训练后,一个想法油然而生,怎样将训练好的模存保存呢?保存之后又该怎么调用呢?----------于是,我查阅了很多相关资料,经过很多的测试,终于…成功了!现将如何保存和调用训练好的模型总结如下(提一下,本文以CNN为例,数据集是众所周知的MNIST):

1.模型的保存

模型的保存通常有两种方法,一种是Tensorflow老版本中的t.trans.saver,另一种是新版本中的tf.saved_model.builder,博主使用的是后者,所以对于第一种方法也就一带而过。

1.1 tf.train.saver这种方法是使用较多的,但是呢,这种方法更多地是用来保存变量(变量)的,所以博主也就没有了解太多,下文po一段典型代码:

import tensorflow as tf 
... 
#在这里构建网络
... 
#开始保存模型
与tf.Session()作为sess:
	sess.run(tf.global_variables_initializer())#一定要先初始化整个流
	#在这里训练网络
	... 
	#保存参数
	saver = tf.train.Saver()
		saver.save(sess,PATH)#PATH就是要保存的路径	

1.2 tf.saved_model.builder

下文只放出保存模型的关键代码,至于训练的过程,会在后续一并放出,整体分析。

将tensorflow import 为tf
...
#构建网络
...
用tf.Session()作为sess:
	sess.run(tf.global_variables_initializer())#一定要先初始化整个流
	#在这里训练网络
	...
	#保存参数
	builder = tf.saved_model.builder.SaveModelBuilder(PATH)#PATH是保存路径
	builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING])#保存整张网络及其变量,这种方法是可以保存多张网络的,在此不作介绍,可自行了解
	builder.save()#完成保存

2.模型的调用

以上就简单介绍了两种保存模型的方法,接下来介绍一下读取模型的方法,只讲基于tf.saved_model.builder方法的读取模型了>划重点!完整调用并使用一个模型,分为加载网络和加载关键节点两个部分。幸运的是,使用tf.saved_model.builder保存和调用模型,这一切都变得非常简单~~~

2.1加载网络废话不多说,直接po出代码:

将tensorflow import 为tf 
用tf.Session(graph = tf.Graph())作为sess:
	tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.TRAINING],PATH)#PATH还是路径
	... 
	...

emmmm …只要通过这么几句代码,就可以完成模型的调用了,但是呢,想要真正地使用这个网络,关键点在于加载网络中的关键节点,如:输入,输出

2.2加载张量(节点,Tensor,变量)神经元,可以是加载张量,也可以是加载节点,变量,代码也非常简单,在上节的基础上,加载张量的代码就只有一句话:

#以加载输入变量为例吧
input_x = sess.graph.get_tensor_by_name('input / input_x:0')#传入的参数会在后文讲到的

3.使用训练好的网络

当把输入的变量以及输出的变量都加载后,就可以使用网络了,假设输出变量为输出,输入变量为X和Y,那么使用这个网络的代码就可以表示如下:

sess.run(output,{
   
   x:_x,y:_y})#大括号的内容就是feed_dict了,使用方法类似于占位符

简要的思路

就结合我的代码来具体分析一下整个过程吧,先说一下简要的思路:

  • 训练网络,并构建TensorBoard中的图形图,用于确定网络的关键节点,保存整个网络;
  • 通过TensorBoard查找整个的数据流向,确定输入节点,输出节点;

以下是CNN网络的训练和保存部分:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dat
### 大模型课程资料与全栈项目实践方案 #### 一、大模型学习资源概述 当前,随着人工智能技术的发展,大模型(Large Language Models, LLMs)已经成为热门领域之一。为了系统化地掌握大模型的知识技术,可以从以下几个方面入手:获取详尽的学习资料[^1],参与专业的培训课程[^2],结合实际项目进行动手实践[^3]。 以下是针对大模型学习的一些建议推荐资源: - **学习路线图**:一份清晰的大模型学习路线图可以帮助初学者快速入门,逐步深入理解复杂的技术细节。 - **商业落地方案**:了解如何将大模型应用于实际场景中至关重要。已有超过百套的商业化落地方案可供参考,这些方案能够帮助开发者更好地规划自己的项目方向。 - **视频教程与文档阅读**:观看高质量的教学视频以及研读经典书籍都是不可或缺的部分。目前存在大量免费或付费形式的大模型相关视频教程及PDF版本教材供选择。 #### 二、具体课程内容解析 对于想要成为AI大模型全栈工程师的人来说,“某乎AI大模型全栈工程师”系列课程是一个不错的选择。该课程具有以下特点: - 提供了从基础知识到高级技巧全覆盖的内容体系; - 注重实战操作能力培养,在多个模块里安排有具体的coding练习环节; - 设置专门章节讲解如何制定科学合理的项目计划书; - 定期邀请业内专家分享最新趋势见解; 此外,还有其他类似的在线教育平台也开设有关于构建端到端解决方案方面的专项训练营[^4]。 #### 三、全栈工程项目实例说明 完成理论知识积累之后,则需要投入到真实的工程环境中锻炼自己解决问题的能力。下面列举几个典型的全栈型任务作为例子: 1. 数据预处理阶段涉及爬取网络公开数据源、标注样本标签等工作流程; 2. 利用PyTorch/TensorFlow框架搭建神经网络结构参数调优; 3. 部署经过验证有效的预测服务接口至云端环境以便外部访问调用. 以上每一步都需要紧密配合才能达成最终目标——即创造出既高效又可靠的智能化工具服务于社会大众的需求. ```python import torch from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("chatglm2") model = AutoModelForCausalLM.from_pretrained("chatglm2") def fine_tune_model(data): optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) for epoch in range(epochs): model.train() total_loss = 0 for batch in data: inputs = tokenizer(batch['text'], return_tensors="pt", truncation=True, padding=True).to('cuda') outputs = model(**inputs, labels=inputs.input_ids) loss = outputs.loss optimizer.zero_grad() loss.backward() optimizer.step() fine_tune_model(training_data) ``` 此代码片段展示了如何利用Hugging Face库加载预训练好的ChatGLM2模型对其进行微调的过程。 ---
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值