Tensorflow代码学习-8-1saver_save

该博客介绍了如何使用TensorFlow加载MNIST数据集,并构建一个简单的神经网络进行手写数字识别。通过梯度下降法优化损失函数,逐步提高测试集上的识别准确性。在训练结束后,使用Saver保存模型。

神经网络saver_save

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data   #手写数字相关的数据包
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)    #载入数据,{数据集包路径,把标签转化为只有0和1的形式}

#定义变量,即每个批次的大小
batch_size = 100    #一次放100章图片进去
n_batch = mnist.train.num_examples // batch_size   #计算一共有多少个批次;训练集数量(整除)一个批次大小

#定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])    #[行不确定,列为784]
y = tf.placeholder(tf.float32,[None,10])    #数字为0-9,则为10

#创建简单的神经网络
W = tf.Variable(tf.zeros([784,10]))   #权重
b = tf.Variable(tf.zeros([10]))     #偏置
prediction = tf.nn.softmax(tf.matmul(x,W)+b)    #预测

#定义二次代价函数
#loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

#准确数,结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))   #比较两个参数大小是否相同,同则返回为true,不同则返回为false;argmax():返回张量中最大的值所在的位置

#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))   #cast():将布尔型转换为32位的浮点型;(比方说9个T和1个F,则为9个1,1个0,即准确率为90%)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(11):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
            
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc))
    #保存模型
    saver.save(sess,'net/my_net.ckpt')

Iter0,Testing Accuracy0.8256
Iter1,Testing Accuracy0.8897
Iter2,Testing Accuracy0.9001
Iter3,Testing Accuracy0.9056
Iter4,Testing Accuracy0.9084
Iter5,Testing Accuracy0.9095
Iter6,Testing Accuracy0.9119
Iter7,Testing Accuracy0.9141
Iter8,Testing Accuracy0.9154
Iter9,Testing Accuracy0.9158
Iter10,Testing Accuracy0.9173

点赞,关注,收藏👍,➕,👀点赞,关注,收藏👍,➕,👀点赞,关注,收藏👍,➕,👀
😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘😘
💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪💪

在这里插入图片描述在这里插入图片描述

欢迎大噶关注我的微信公众号

<think>好的,我现在需要解决用户提到的两个问题:TensorFlow 2.0中contrib模块不可用的问题,以及ipykernel_launcher.py报错unrecognized arguments的问题。先分开来看这两个问题,再逐个分析解决。 首先是TensorFlow 2.0的contrib模块替代方案。根据用户提供的引用内容,特别是引用[2]和引用[3],了解到TensorFlow 2.0移除了contrib模块,将其中的功能分散到其他核心模块或需要单独安装的库中,比如tensorflow_addons。例如,原来在contrib中的CRF功能现在需要安装tensorflow_addons并使用其中的模块。此外,像rnn_cell这样的模块在TensorFlow 2.0中可能需要从tensorflow.python.ops导入。用户提到的HParams的问题,引用[4]中说明tf.contrib.training.HParams在2.0中已失效,可能需要使用其他方法,比如将参数封装到Python类中,或者使用第三方库如hparams。 然后是ipykernel_launcher.py的unrecognized arguments错误,错误信息中提到了参数--f=...。这可能是因为在运行Jupyter notebook或相关环境时,传递了不被识别的参数。需要检查代码中是否有误用的参数,特别是argparse解析的参数是否存在冲突。例如,用户可能在代码中定义了一个参数解析器,但实际运行时传递了未被解析的参数,如--f,这会导致错误。解决办法是检查代码中argparse部分,添加对应的参数定义,或者检查运行命令是否有误。 接下来,针对每个问题详细思考解决方案: 对于TensorFlow contrib模块的替代方案: 1. **检查功能迁移路径**:根据官方文档或社区资源,确定原contrib模块中的功能在TensorFlow 2.0中的位置。例如,rnn_cell可能迁移到tensorflow.python.ops.rnn_cell,或者需要调整导入方式。 2. **使用tensorflow_addons**:对于某些高级功能,如CRF层,安装并使用tensorflow_addons库,该库提供了许多contrib模块中的功能。 3. **修改导入语句**:例如,将`from tensorflow.contrib import rnn`改为`from tensorflow.python.ops import rnn_cell`或者使用兼容性模块`tf.compat.v1`。 4. **重构代码使用核心API**:例如,替换tf.placeholder为tf.Variable或使用tf.data API,替换tf.train.Saver为tf.compat.v1.train.Saver等。 5. **参数管理替代方案**:如HParams,可以使用自定义类或第三方库如gin-config或hparams。 对于ipykernel_launcher.py的unrecognized arguments错误: 1. **检查代码中的参数解析**:查看是否有使用argparse或类似库定义的参数解析,是否遗漏了某些参数的声明,尤其是--f参数是否被正确解析。 2. **检查Jupyter环境参数**:可能在启动kernel时传递了不必要的参数,需要确认这些参数是否需要,或者是否应该被忽略。 3. **添加忽略未知参数的选项**:在argparse中设置`parse_known_args()`而不是`parse_args()`,这样未识别的参数不会导致错误,但需要注意潜在影响。 4. **检查运行命令或配置**:确认是否有脚本或配置文件错误地添加了--f参数,可能需要修改启动配置或环境变量。 综合用户的问题和提供的引用,特别是引用[3]提到的解决方案,可能需要升级代码TensorFlow 2.0的API,并使用替代库。同时,对于参数错误,需调整参数解析部分。 最后,生成相关问题,帮助用户深入理解问题,如替代方案的具体步骤、如何升级代码、如何处理参数错误的具体方法等。</think>### TensorFlow 2.0 contrib模块替代方案与参数错误解决方案 #### 一、TensorFlow 2.0 contrib模块替代方案 TensorFlow 2.0移除了`contrib`模块,需根据具体功能迁移到以下替代方案: 1. **核心API迁移** - **RNN模块**:原`tf.contrib.rnn`中的单元(如`BasicLSTMCell`)需改为: ```python from tensorflow.python.ops import rnn_cell cell = rnn_cell.BasicLSTMCell(num_units) ``` - **模型保存**:`tf.train.Saver`改为兼容模式: ```python saver = tf.compat.v1.train.Saver() ``` 2. **使用tensorflow_addons** - 安装附加库: ```bash pip install tensorflow-addons ``` - **CRF模块**替代: ```python from tensorflow_addons.text import crf_log_likelihood ``` 3. **HParams参数管理** - 使用自定义类或第三方库: ```python class HParams: def __init__(self, **kwargs): self.__dict__.update(kwargs) hparams = HParams(learning_rate=0.001, batch_size=32) ``` 4. **兼容性模块** - 启用TensorFlow 1.x行为: ```python import tensorflow.compat.v1 as tf tf.disable_v2_behavior() ``` #### 二、ipykernel_launcher.py参数错误解决 错误`unrecognized arguments --f=...`通常因参数解析冲突引起,解决方法如下: 1. **检查代码中的argparse定义** - 若使用`argparse`,需明确定义所有参数: ```python import argparse parser = argparse.ArgumentParser() parser.add_argument("--f", help="JSON file path") args = parser.parse_args() ``` 2. **忽略未知参数** - 使用`parse_known_args()`允许未定义参数: ```python args, unknown = parser.parse_known_args() ``` 3. **排查Jupyter配置** - 检查Jupyter内核配置文件,移除多余的`--f`参数引用。 #### 三、完整代码示例(以MNIST数据读取为例) ```python import tensorflow.compat.v1 as tf tf.disable_v2_behavior() # 替代原contrib.datasets的MNIST读取方式 (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() ``` #### 四、引用说明 - TensorFlow 2.0移除了contrib模块,需依赖核心API或附加库[^3]。 - `tf.contrib.training.HParams`可通过自定义类实现[^4]。 - 参数解析错误需检查argparse定义与运行参数[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值