DL实战(1):tensorflow在mnist上实现siamese net

本文介绍了在TensorFlow中实现Siamese网络以解决类别多但样本少的问题。在实践中遇到并解决了包括tf.sub和tf.mul属性错误、模型保存和加载路径问题、raw_input未定义、读取checkpoint文件以及embed.txt数据错误等五个问题,并提供了相应的解决方案。文章还简要解释了部分代码功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、写在前面

参考:https://www.cnblogs.com/bentuwuying/p/8186364.html

简介:全连接孪生网络(siamese network)是一种相似性度量方法,适用于类别数目多但是每类的样本数少的分类问题。

代码:https://github.com/Shicoder/DeepLearning_Demo/tree/master/siamese_tf_mnist

二、测试代码

按照github代码运行python run.py后报错:

【报错1】

AttributeError: module ‘tensorflow‘ has no attribute ‘sub‘

AttributeError: module ‘tensorflow‘ has no attribute ‘mul‘

【解决1】

将代码inference.py中的tf.sub替换为tf.subtract,tf.mul替换为tf.multiply

----------------------------------------------------------------------------------------------------------------

【报错2】

NotFoundError (see above for traceback): Failed to create a directory: ; No such file or directory

         [[Node: save/SaveV2 = SaveV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/SaveV2/tensor_names, save/SaveV2/shape_and_slices, siamese/fc1W, siamese/fc1b, siamese/fc2W, siamese/fc2b, siamese/fc3W, siamese/fc3b)]]

【解决2】

将代码run.py中的saver.save(sess, 'model.ckpt')以及saver.restore(sess, 'model.ckpt')改为saver.save(sess, './model.ckpt')以及saver.restore(sess, './model.ckpt')

----------------------------------------------------------------------------------------------------------------

【报错3】

name 'raw_input' is not defined

【解决3】

将代码中的raw_input改为input即可。

----------------------------------------------------------------------------------------------------------------

【报错4】

读取model.ckpt失败

【解决4】

将代码中的os.path.isfile(model_ckpt)改为判断当前路径下是否存在checkpoint这个文件os.path.exists('E:\pyProject\DeepLearning_Demo-master\siamese_tf_mnist\checkpoint'): 

----------------------------------------------------------------------------------------------------------------

【报错5】

读取embed.txt数据错误

【解决5】

embed.tofile("embed.txt")#保存为二进制文件,且不能保存当前数据的行列信息

 

用np.fromfile读取数据需要手动指定dtype,如果指定的格式与保存时的不一致,则读出来的就是错误的数据。

print(embed.shape)

print(embed.dtype)

embed=np.fromfile('embed.txt',dtype=np.float32) #读取数据

注意到读出来的数据是一维数组,需要利用np.reshape方法重新指定维数

 

三、部分代码解读

 

from __future__ import absolute_import #使用py3的绝对引入
from __future__ import division #使用py3的精确除法
from __future__ import print_function #使用py3的print()功能函数

__future__模块,把下一个新版本的特性导入到当前版本,从而能够在当前旧版本中测试一些新版本的特性。

 

 

四、实验结果

 

 

 

 

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值